PythonOT/POT 快速入门指南:最优传输与机器学习实践

为什么需要最优传输?

最优传输(Optimal Transport, OT)是1781年由Gaspard Monge提出的数学问题,旨在寻找在分布之间转移质量的最有效方式。在机器学习领域,最优传输已经成为衡量分布相似性和进行知识迁移的强大工具。

最优传输的核心价值

最优传输的核心在于两个关键输出:

  1. 最优值(Wasserstein距离):衡量分布之间的相似性
  2. 最优映射(Monge映射或OT矩阵):发现分布之间的对应关系
Wasserstein距离的优势

与传统f-散度(如KL散度、JS散度)相比,Wasserstein距离具有独特优势:

  • 能够处理支撑集不重叠的分布
  • 提供有意义的次梯度
  • 在数据科学应用中计算友好

这些特性使其在GAN训练、判别子空间发现、文档嵌入相似性比较等场景中表现出色。

映射估计的应用

OT矩阵本身提供了样本间的对应关系,这种无监督的对应关系发现能力在以下场景非常有用:

  • 图像间的颜色迁移
  • 领域自适应问题
  • 词嵌入空间的语言对齐(通过Gromov-Wasserstein扩展)

PythonOT/POT工具包概览

PythonOT/POT专为机器学习场景中的最优传输问题而设计,提供了多种求解器的实现,旨在促进可重复研究和新算法开发。

适用场景

POT特别适合以下情况:

  • 需要精确OT解的研究工作
  • 需要灵活扩展的算法开发
  • 中等规模的数据集(样本量在数千级别)

不适用场景

POT在以下情况可能不是最佳选择:

  • 超大规模数据集(样本量超过数万)
  • 内存受限环境(OT问题需要O(n²)内存)
  • 实时性要求极高的应用

对于大规模问题,建议考虑使用GeomLoss等内存效率更高的实现,或者采用小批量Wasserstein距离近似方法。

基础OT问题求解

Kantorovich公式

离散分布的最优传输问题通常表述为:

γ* = argmin_{γ∈ℝ₊^{m×n}} ∑γ_{i,j}M_{i,j} s.t. γ1 = a; γᵀ1 = b; γ ≥ 0

其中:

  • M是度量成本矩阵
  • a和b是单纯形上的直方图(正值且和为1)

使用POT求解

POT提供了两种形式的函数:

  • 返回OT矩阵的函数(如ot.emd)
  • 返回最优值的函数(如ot.emd2)
# 计算OT矩阵
T = ot.emd(a, b, M)  # 精确线性规划

# 计算Wasserstein距离
W = ot.emd2(a, b, M)  # 直接返回最优值

POT使用网络单纯形法(C语言实现)求解,复杂度为O(n³),但实际效率较高。

特殊情况的优化

一维分布

对于一维样本,OT问题可在O(n log n)时间内解决:

# 一维OT矩阵
T_1d = ot.emd_1d(xs, xt, a, b)

# 一维Wasserstein距离
Wp = ot.wasserstein_1d(xs, xt, a, b, p=2)  # W2距离
高斯分布

对于高斯分布,存在闭式解:

# 计算高斯分布间的Bures-Wasserstein映射
A, b = ot.gaussian.bures_wasserstein_mapping(mu_s, mu_t, cov_s, cov_t)

正则化最优传输

正则化OT在计算和统计特性上都有优势,POT支持多种正则化形式。

熵正则化OT

最常用的正则化形式,由Marco Cuturi引入:

Ω(γ) = ∑γ_{i,j}log(γ_{i,j})

熵正则化使问题:

  1. 变得平滑
  2. 严格凸
  3. 有唯一解

解的形式为:γ_λ* = diag(u)Kdiag(v)

POT提供了多种Sinkhorn算法变体:

# 基础Sinkhorn算法
T_reg = ot.sinkhorn(a, b, M, reg=1.0)
算法选择建议
  1. 默认情况method='sinkhorn'
  2. 小正则化参数method='sinkhorn_stabilized'
  3. 数值稳定性要求高method='sinkhorn_log'
  4. 大规模问题method='greenkhorn'method='screenkhorn'

Sinkhorn散度

Genevay等人提出的Sinkhorn散度提供了快速可微的几何散度计算:

# 计算经验分布的Sinkhorn散度
div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, reg=1.0)

实践建议

  1. 数据预处理:确保输入分布a和b是归一化的直方图
  2. 成本矩阵选择:根据问题特性选择合适的距离度量
  3. 正则化参数:从小值开始逐步调整,平衡精度与计算效率
  4. 算法选择:根据问题规模和精度需求选择合适的求解器

通过合理使用PythonOT/POT,开发者可以在机器学习任务中高效地应用最优传输理论,解决分布比较和知识迁移等核心问题。

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐