GAN生成对抗网络小记
生成对抗网络(Generative Adversarial Network, GAN)是一个。GAN 训练是一个交替优化的过程,通常采用**梯度下降(SGD, Adam)**来更新。GAN 目前广泛用于图像生成、风格转换、语音合成等领域,是最重要的生成模型之一。生成的数据与真实数据无法区分。
生成对抗网络(GAN)深入解析:数学原理与优化
生成对抗网络(Generative Adversarial Network, GAN)是一个基于博弈论的深度学习框架,通过生成器(G)和判别器(D)之间的对抗训练,生成高度逼真的数据。其核心思想是让 GGG 生成伪造数据以欺骗 DDD,而 DDD 则努力分辨真实数据与伪造数据。GAN 在理论上可以看作一个极小极大(Minimax)优化问题。
1. GAN 的数学公式
1.1 生成器与判别器的定义
- 生成器 G(z)G(z)G(z): 输入一个随机噪声 z∼pz(z)z \sim p_z(z)z∼pz(z)(通常为高斯分布或均匀分布),输出一个生成样本 G(z)G(z)G(z),试图让这个样本与真实样本相似。
- 判别器 D(x)D(x)D(x): 输入一个样本 xxx,输出一个介于 0 和 1 之间的概率 D(x)D(x)D(x),表示样本是真实数据的概率。
1.2 GAN 的目标函数
GAN 采用极小极大(Minimax)损失函数,其目标是让生成器尽可能生成真实数据,而判别器尽可能区分真实数据和伪造数据:
minGmaxDV(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))] \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log (1 - D(G(z)))] GminDmaxV(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]
其中:
- pdata(x)p_{\text{data}}(x)pdata(x) 是真实数据的分布,
- pz(z)p_z(z)pz(z) 是输入噪声的分布,
- D(x)D(x)D(x) 试图最大化分类准确率,
- G(z)G(z)G(z) 试图最小化判别器的分类能力。
1.3 最优判别器
如果固定生成器 GGG,则判别器 DDD 需要最大化目标函数:
V(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))] V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log (1 - D(G(z)))] V(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]
可以证明,最优判别器 D∗(x)D^*(x)D∗(x) 的形式是:
D∗(x)=pdata(x)pdata(x)+pg(x) D^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)} D∗(x)=pdata(x)+pg(x)pdata(x)
其中,pg(x)p_g(x)pg(x) 是由生成器 G(z)G(z)G(z) 生成的数据的分布。
2. 训练过程
GAN 训练是一个交替优化的过程,通常采用**梯度下降(SGD, Adam)**来更新 GGG 和 DDD:
-
训练判别器 DDD:
- 取真实数据 x∼pdata(x)x \sim p_{\text{data}}(x)x∼pdata(x),计算 D(x)D(x)D(x) 并最大化 logD(x)\log D(x)logD(x)。
- 取生成数据 G(z)G(z)G(z) 使 D(G(z))D(G(z))D(G(z)) 尽可能小,即最小化 log(1−D(G(z)))\log(1 - D(G(z)))log(1−D(G(z)))。
- 更新判别器参数 θD\theta_DθD 以提高区分能力。
-
训练生成器 GGG:
- 生成数据 G(z)G(z)G(z),然后让判别器对其分类。
- 生成器希望让 D(G(z))D(G(z))D(G(z)) 输出接近 1(让判别器误判)。
- 生成器的优化目标是最小化 log(1−D(G(z)))\log(1 - D(G(z)))log(1−D(G(z))),或者直接最大化 logD(G(z))\log D(G(z))logD(G(z))(这被称为改进版 GAN 损失)。
-
循环迭代,直到 GGG 生成的数据与真实数据无法区分。
3. 训练挑战
3.1 模式崩溃(Mode Collapse)
- 生成器可能学会只生成一小部分数据,而非整个数据分布。
- 例如,GGG 只生成某一类图像,导致 DDD 很容易识别 GGG 的模式。
解决方案:
- Minibatch Discrimination:让 DDD 学习样本之间的多样性,防止 GGG 只生成少量模式。
- Unrolled GAN:考虑 DDD 在未来几步更新中的影响,使 GGG 不会局部最优。
3.2 训练不稳定
- GAN 训练是非凸优化问题,可能导致梯度消失或振荡。
- 训练过程中,GGG 和 DDD 的能力必须匹配,否则其中一方会迅速胜出,导致训练失败。
解决方案:
- 使用 WGAN(Wasserstein GAN):WGAN 使用 Wasserstein 距离替代 KL 散度,使训练更加稳定。
- 调整判别器与生成器的更新频率:例如,训练判别器多步,再训练一次生成器。
3.3 梯度消失
- 当 DDD 过强时,GGG 可能学不到有意义的梯度。
- 由于 log(1−D(G(z)))\log(1 - D(G(z)))log(1−D(G(z))) 在 D(G(z))D(G(z))D(G(z)) 远离 1 时梯度趋于 0,GGG 可能难以更新。
解决方案:
- 使用改进损失 maxGlogD(G(z))\max_G \log D(G(z))maxGlogD(G(z)) 以提供更好的梯度信号。
- 使用 Batch Normalization 或 调整学习率。
4. GAN 的变种
- DCGAN(深度卷积 GAN):使用卷积神经网络(CNN),提升图像质量。
- CGAN(条件 GAN):在 GGG 和 DDD 额外输入条件信息(如类别标签)。
- WGAN(Wasserstein GAN):使用 Wasserstein 距离替代交叉熵,提高训练稳定性。
- StyleGAN:用于高分辨率人脸生成,生成风格可控。
- CycleGAN:用于图像到图像的转换(如将马变成斑马)。
5. 总结
- GAN 通过博弈思想训练生成模型,使得生成数据逐步逼近真实数据分布。
- 核心数学公式:GAN 是一个极小极大优化问题,目标是让 GGG 生成的分布 pg(x)p_g(x)pg(x) 逼近真实分布 pdata(x)p_{\text{data}}(x)pdata(x)。
- 训练时,GGG 和 DDD 交替优化,使得 GGG 学会欺骗 DDD,最终生成高质量数据。
- 存在模式崩溃、梯度消失等问题,改进版本如 WGAN 和 StyleGAN 解决了一些问题。
GAN 目前广泛用于图像生成、风格转换、语音合成等领域,是最重要的生成模型之一。
更多推荐
所有评论(0)