1. GAN底层原理解析

1.1 核心思想

生成对抗网络(GAN)通过博弈论框架实现数据生成,其核心是生成器(Generator)与判别器(Discriminator)的对抗训练:

1)双网络架构
   - 生成器(G):将随机噪声z映射到数据空间,G(z) → 伪样本  
   - 判别器(D):区分真实数据(P_{data})与生成数据(P_{g}),输出概率值  
   - 目标函数(极小极大博弈):  
     $$ \min_G \max_D V(D,G) = \mathbb{E}_{x\sim p_{data}}[\log D(x)] + \mathbb{E}_{z\sim p_z}[\log(1-D(G(z)))] $$

2)纳什均衡
   - 理想状态:当 P_{g} = P_{data} 时达到平衡(D无法区分真伪)  
   - 数学证明:当G固定时,最优判别器  
     $$ D^*(x) = \frac{P_{data}(x)}{P_{data}(x) + P_g(x)} $$

3)训练动态
   - 判别器训练:最大化真实样本得分,最小化生成样本得分  
   - 生成器训练:欺骗判别器,使其对G(z)给出高评分  
   - 交替训练策略(关键实现细节)

1.2 数学推导

1)KL散度视角 
GAN的目标函数等价于最小化 P_{g} 与 P_{data} 的JS散度:  
$$ JSD(P_{data} \| P_g) = \frac{1}{2}KL(P_{data} \| \frac{P_{data}+P_g}{2}) + \frac{1}{2}KL(P_g \| \frac{P_{data}+P_g}{2}) $$

1.3 训练难点

1)模式崩溃(Mode Collapse)
   - 生成器仅生成有限模式样本  
   - 解决方案:Mini-batch Discrimination、Unrolled GAN  

2)梯度消失
   - 当D过于强大时,G的梯度趋近于零  
   - 改进目标函数:NS-GAN使用 -E[logD(G(z))]

3)评估指标
   - Inception Score (IS)、Fréchet Inception Distance (FID)  
   - 避免主观评价偏差

2. 基于PyTorch代码实现

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# 超参数配置
latent_dim = 100          # 噪声向量维度
img_size = 28             # 图像尺寸
batch_size = 128          # 批大小
epochs = 50               # 训练轮数
lr = 0.0002               # 学习率
b1 = 0.5                  # Adam优化器参数
b2 = 0.999

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),                 # 转换为Tensor [0,1]
    transforms.Normalize([0.5], [0.5])     # 归一化到[-1, 1]
])

# MNIST数据集加载
dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True
)

# 生成器定义(全连接网络)
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            # 输入:latent_dim维噪声
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, img_size**2),  # 输出28x28=784维
            nn.Tanh()  # 输出范围[-1,1] 与输入归一化匹配
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, img_size, img_size)
        return img

# 判别器定义
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(img_size**2, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()  # 输出概率值
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity

# 初始化网络
generator = Generator()
discriminator = Discriminator()

# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)

# 定义损失函数和优化器
adversarial_loss = nn.BCELoss()  # 二分类交叉熵

optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# 训练循环
for epoch in range(epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # ---------------------
        #  训练判别器
        # ---------------------
        
        # 真实数据
        real_imgs = imgs.to(device)
        batch_size = real_imgs.size(0)
        
        # 标签处理
        valid = torch.ones(batch_size, 1).to(device)  # 真实样本标签为1
        fake = torch.zeros(batch_size, 1).to(device)  # 生成样本标签为0

        # 梯度清零
        optimizer_D.zero_grad()

        # 真实样本的损失
        real_loss = adversarial_loss(discriminator(real_imgs), valid)

        # 生成假样本
        z = torch.randn(batch_size, latent_dim).to(device)  # 噪声采样
        gen_imgs = generator(z).detach()  # 阻断生成器梯度传播

        # 假样本的损失
        fake_loss = adversarial_loss(discriminator(gen_imgs), fake)

        # 总判别器损失
        d_loss = (real_loss + fake_loss) / 2

        # 反向传播+优化
        d_loss.backward()
        optimizer_D.step()

        # -----------------
        #  训练生成器
        # -----------------

        optimizer_G.zero_grad()

        # 生成新样本(需要梯度)
        gen_imgs = generator(z)

        # 生成器目标:让判别器认为生成样本为真
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        # 反向传播+优化
        g_loss.backward()
        optimizer_G.step()

        # 打印训练信息
        if i % 200 == 0:
            print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] "
                  f"D Loss: {d_loss.item():.4f} G Loss: {g_loss.item():.4f}")

# 生成示例图像
z = torch.randn(16, latent_dim).to(device)
gen_imgs = generator(z).detach().cpu()

# 可视化结果
plt.figure(figsize=(4,4))
for i in range(16):
    plt.subplot(4,4,i+1)
    plt.imshow(gen_imgs[i].squeeze(), cmap='gray')
    plt.axis('off')
plt.show()

3. 代码解释

3.1 生成器网络

nn.Linear(latent_dim, 256)

- 将100维噪声向量映射到256维隐藏层  
- 典型设计:逐层扩大维度,最后收缩到图像尺寸

nn.Tanh()  # 输出范围[-1,1]

- 匹配输入数据的归一化范围(transforms.Normalize([0.5], [0.5]))  
- 确保生成图像像素值在合理区间

3.2 判别器网络

nn.Sigmoid()  # 输出概率值

- 将判别器输出压缩到[0,1]区间  
- 表示输入图像为真实样本的概率

3.3 训练循环关键代码

real_loss = adversarial_loss(discriminator(real_imgs), valid)

- 计算真实样本的交叉熵损失  
- 目标:让判别器对真实样本输出接近1

gen_imgs = generator(z).detach()

- `detach()` 切断生成器梯度传播(仅在训练D时)  
- 防止判别器梯度影响生成器参数

g_loss = adversarial_loss(discriminator(gen_imgs), valid)

- 生成器训练目标:让判别器对生成样本打高分  
- 注意此处没有detach,允许梯度传播到G

3.4 优化器配置

optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

- Adam优化器的动量参数  
- b1=0.9通常效果更好

4. 改进建议

1)架构升级
   - DCGAN:使用卷积层代替全连接  

# 生成器示例
nn.ConvTranspose2d(256, 128, 4, 2, 1)  # 转置卷积上采样
nn.BatchNorm2d(128)
nn.ReLU()

2)稳定训练技巧
   - 标签平滑:将真实标签设为0.9~1.0  
   - 单侧标签平滑:仅对真实样本应用

3)损失函数改进
   - Wasserstein GAN(WGAN):  

discriminator.model[-1] = nn.Identity()  # 移除Sigmoid
optimizer_D = RMSprop()  # 使用RMSprop

4)评估指标集成

# 计算FID分数
from pytorch_fid import calculate_fid
fid_value = calculate_fid(real_imgs, gen_imgs)

5)生产级部署
   - ONNX导出:`torch.onnx.export()`  
   - TensorRT加速

6)大模型应用
   - StyleGAN的渐进式训练策略  
   - BigGAN的大批量训练技术  
   - GAN与Diffusion Model的混合架构

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐