GAN生成对抗网络原理解析及基于pytorch的代码实现
生成对抗网络(GAN)通过博弈论框架实现数据生成,其核心是生成器(Generator)与判别器(Discriminator)的对抗训练:1)双网络架构- 生成器(G):将随机噪声z映射到数据空间,G(z) → 伪样本- 判别器(D):区分真实数据(P_data)与生成数据(P_g),输出概率值- 目标函数(极小极大博弈):$$ \min_G \max_D V(D,G) = \mathbb{E}_{
1. GAN底层原理解析
1.1 核心思想
生成对抗网络(GAN)通过博弈论框架实现数据生成,其核心是生成器(Generator)与判别器(Discriminator)的对抗训练:
1)双网络架构
- 生成器(G):将随机噪声z映射到数据空间,G(z) → 伪样本
- 判别器(D):区分真实数据()与生成数据(
),输出概率值
- 目标函数(极小极大博弈):
$$ \min_G \max_D V(D,G) = \mathbb{E}_{x\sim p_{data}}[\log D(x)] + \mathbb{E}_{z\sim p_z}[\log(1-D(G(z)))] $$
2)纳什均衡
- 理想状态:当 时达到平衡(D无法区分真伪)
- 数学证明:当G固定时,最优判别器
$$ D^*(x) = \frac{P_{data}(x)}{P_{data}(x) + P_g(x)} $$
3)训练动态
- 判别器训练:最大化真实样本得分,最小化生成样本得分
- 生成器训练:欺骗判别器,使其对G(z)给出高评分
- 交替训练策略(关键实现细节)
1.2 数学推导
1)KL散度视角
GAN的目标函数等价于最小化 与
的JS散度:
$$ JSD(P_{data} \| P_g) = \frac{1}{2}KL(P_{data} \| \frac{P_{data}+P_g}{2}) + \frac{1}{2}KL(P_g \| \frac{P_{data}+P_g}{2}) $$
1.3 训练难点
1)模式崩溃(Mode Collapse)
- 生成器仅生成有限模式样本
- 解决方案:Mini-batch Discrimination、Unrolled GAN
2)梯度消失
- 当D过于强大时,G的梯度趋近于零
- 改进目标函数:NS-GAN使用
3)评估指标
- Inception Score (IS)、Fréchet Inception Distance (FID)
- 避免主观评价偏差
2. 基于PyTorch代码实现
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# 超参数配置
latent_dim = 100 # 噪声向量维度
img_size = 28 # 图像尺寸
batch_size = 128 # 批大小
epochs = 50 # 训练轮数
lr = 0.0002 # 学习率
b1 = 0.5 # Adam优化器参数
b2 = 0.999
# 数据预处理
transform = transforms.Compose([
transforms.ToTensor(), # 转换为Tensor [0,1]
transforms.Normalize([0.5], [0.5]) # 归一化到[-1, 1]
])
# MNIST数据集加载
dataset = torchvision.datasets.MNIST(
root='./data',
train=True,
download=True,
transform=transform
)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=True
)
# 生成器定义(全连接网络)
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
# 输入:latent_dim维噪声
nn.Linear(latent_dim, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(1024, img_size**2), # 输出28x28=784维
nn.Tanh() # 输出范围[-1,1] 与输入归一化匹配
)
def forward(self, z):
img = self.model(z)
img = img.view(img.size(0), 1, img_size, img_size)
return img
# 判别器定义
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(img_size**2, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
nn.Sigmoid() # 输出概率值
)
def forward(self, img):
img_flat = img.view(img.size(0), -1)
validity = self.model(img_flat)
return validity
# 初始化网络
generator = Generator()
discriminator = Discriminator()
# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)
# 定义损失函数和优化器
adversarial_loss = nn.BCELoss() # 二分类交叉熵
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))
# 训练循环
for epoch in range(epochs):
for i, (imgs, _) in enumerate(dataloader):
# ---------------------
# 训练判别器
# ---------------------
# 真实数据
real_imgs = imgs.to(device)
batch_size = real_imgs.size(0)
# 标签处理
valid = torch.ones(batch_size, 1).to(device) # 真实样本标签为1
fake = torch.zeros(batch_size, 1).to(device) # 生成样本标签为0
# 梯度清零
optimizer_D.zero_grad()
# 真实样本的损失
real_loss = adversarial_loss(discriminator(real_imgs), valid)
# 生成假样本
z = torch.randn(batch_size, latent_dim).to(device) # 噪声采样
gen_imgs = generator(z).detach() # 阻断生成器梯度传播
# 假样本的损失
fake_loss = adversarial_loss(discriminator(gen_imgs), fake)
# 总判别器损失
d_loss = (real_loss + fake_loss) / 2
# 反向传播+优化
d_loss.backward()
optimizer_D.step()
# -----------------
# 训练生成器
# -----------------
optimizer_G.zero_grad()
# 生成新样本(需要梯度)
gen_imgs = generator(z)
# 生成器目标:让判别器认为生成样本为真
g_loss = adversarial_loss(discriminator(gen_imgs), valid)
# 反向传播+优化
g_loss.backward()
optimizer_G.step()
# 打印训练信息
if i % 200 == 0:
print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] "
f"D Loss: {d_loss.item():.4f} G Loss: {g_loss.item():.4f}")
# 生成示例图像
z = torch.randn(16, latent_dim).to(device)
gen_imgs = generator(z).detach().cpu()
# 可视化结果
plt.figure(figsize=(4,4))
for i in range(16):
plt.subplot(4,4,i+1)
plt.imshow(gen_imgs[i].squeeze(), cmap='gray')
plt.axis('off')
plt.show()
3. 代码解释
3.1 生成器网络
nn.Linear(latent_dim, 256)
- 将100维噪声向量映射到256维隐藏层
- 典型设计:逐层扩大维度,最后收缩到图像尺寸
nn.Tanh() # 输出范围[-1,1]
- 匹配输入数据的归一化范围(transforms.Normalize([0.5], [0.5]))
- 确保生成图像像素值在合理区间
3.2 判别器网络
nn.Sigmoid() # 输出概率值
- 将判别器输出压缩到[0,1]区间
- 表示输入图像为真实样本的概率
3.3 训练循环关键代码
real_loss = adversarial_loss(discriminator(real_imgs), valid)
- 计算真实样本的交叉熵损失
- 目标:让判别器对真实样本输出接近1
gen_imgs = generator(z).detach()
- `detach()` 切断生成器梯度传播(仅在训练D时)
- 防止判别器梯度影响生成器参数
g_loss = adversarial_loss(discriminator(gen_imgs), valid)
- 生成器训练目标:让判别器对生成样本打高分
- 注意此处没有detach,允许梯度传播到G
3.4 优化器配置
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))
- Adam优化器的动量参数
- b1=0.9通常效果更好
4. 改进建议
1)架构升级
- DCGAN:使用卷积层代替全连接
# 生成器示例
nn.ConvTranspose2d(256, 128, 4, 2, 1) # 转置卷积上采样
nn.BatchNorm2d(128)
nn.ReLU()
2)稳定训练技巧
- 标签平滑:将真实标签设为0.9~1.0
- 单侧标签平滑:仅对真实样本应用
3)损失函数改进
- Wasserstein GAN(WGAN):
discriminator.model[-1] = nn.Identity() # 移除Sigmoid
optimizer_D = RMSprop() # 使用RMSprop
4)评估指标集成
# 计算FID分数
from pytorch_fid import calculate_fid
fid_value = calculate_fid(real_imgs, gen_imgs)
5)生产级部署
- ONNX导出:`torch.onnx.export()`
- TensorRT加速
6)大模型应用
- StyleGAN的渐进式训练策略
- BigGAN的大批量训练技术
- GAN与Diffusion Model的混合架构
更多推荐
所有评论(0)