Generative AI with PyTorch: VQ-GAN Explained and Implemented
0. Introduction
With the rapid development of generative models, balancing image quality against modelability has been a core research question in the field. Generative Adversarial Networks (GANs) excel at producing high-fidelity detail but are unstable to train; the Vector Quantised-Variational AutoEncoder (VQ-VAE) provides discrete latent representations that are well suited to sequence models such as Transformers, but its outputs lack fine detail. The Vector Quantized Generative Adversarial Network (VQ-GAN) combines the strengths of both: it builds a discrete latent space through vector quantization (VQ) and couples a perceptual loss with an adversarial loss, generating images with rich texture detail while preserving semantic consistency. In this section, we explain the technical principles and components of VQ-GAN in detail and implement VQ-GAN from scratch in PyTorch.
1. VQ-GAN Core Ideas
The Vector Quantized Generative Adversarial Network (VQ-GAN) combines the discrete representations of vector quantization (VQ) with the adversarially trained reconstruction quality of a Generative Adversarial Network (GAN). The core ideas are as follows:
- Use an encoder to map the image into a low-resolution continuous latent space
- Discretize the latent space (vector quantization): a codebook (which can be thought of as a vocabulary) replaces each continuous vector with its nearest discrete embedding vector, yielding a sequence of discrete symbols
- The decoder (generator) reconstructs a high-quality image from the discrete embeddings
- A perceptual loss (VGG feature loss) and an adversarial loss (e.g. PatchGAN) together encourage the generated image to be realistic at the pixel level and faithful in perceptual features

The discretized symbols can then be learned by downstream models (such as an autoregressive Transformer), enabling high-resolution image generation. The discrete representation is easy to model, the adversarial training of the GAN produces high-fidelity detail, and the perceptual loss preserves semantic consistency.
2. VQ-GAN Architecture
In the Vector Quantised-Variational AutoEncoder (VQ-VAE), the image is first encoded into a latent space and each latent vector is replaced with its nearest codebook vector; the training loss consists of a reconstruction error, a codebook loss, and a commitment loss, and the output is a sequence of discrete indices. A Generative Adversarial Network (GAN), on the other hand, uses adversarial training to push the generated distribution toward the real distribution and can produce samples rich in detail. The perceptual loss measures differences between intermediate-layer features of a pretrained network (such as VGG) and captures "perceptual similarity" better than pixel-level L2/L1 losses. VQ-GAN merges these ingredients: VQ discretizes the latent representation, while the GAN and VGG losses take care of generation quality and perceptual consistency. The model architecture is as follows, consisting overall of four modules:
- Encoder $E$: takes an input image $x \in \mathbb{R}^{H \times W \times 3}$ and outputs a downsampled continuous feature map $z_e = E(x) \in \mathbb{R}^{h \times w \times d}$, where $h = H/s$, $w = W/s$ ($s$ is the downsampling factor, e.g. 8 or 16) and $d$ is the channel dimension
- Vector-quantization codebook $\varepsilon = \{e_k \in \mathbb{R}^d\}_{k=1}^{K}$: each vector $z_e(i)$ of the encoder output is mapped to its nearest neighbor,
$$k^* = \underset{k}{\operatorname{argmin}} \|z_e(i) - e_k\|_2, \qquad z_q(i) = e_{k^*}$$
Both the quantized output $z_q$ and the index sequence $\{k^*\}$ can be used for downstream modeling (see the code sketch after this list)
- Decoder $G$: reconstructs the image from the discrete embeddings, $\hat{x} = G(z_q)$. The decoder is typically a convolutional residual network (or U-Net style) that progressively upsamples back to the original resolution
- Discriminator $D$ (commonly a PatchGAN): distinguishes real from fake images and provides the adversarial loss, pushing $G$ toward locally and globally more realistic detail

In addition, a pretrained VGG network $\phi(\cdot)$ is used during training to compute the perceptual loss.
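To make the nearest-neighbor mapping concrete, below is a minimal, self-contained sketch of the quantization step on a batch of flattened encoder vectors, using illustrative sizes (K = 8 codebook entries of dimension d = 4); torch.cdist computes the pairwise L2 distances used in the argmin:
import torch

K, d, N = 8, 4, 5                       # illustrative codebook size, dimension, number of vectors
codebook = torch.randn(K, d)            # e_k, k = 1..K
z_e = torch.randn(N, d)                 # flattened encoder outputs z_e(i)

distances = torch.cdist(z_e, codebook)  # N x K pairwise L2 distances ||z_e(i) - e_k||_2
indices = distances.argmin(dim=1)       # k* for each vector: the discrete symbols
z_q = codebook[indices]                 # quantized vectors e_{k*}, shape N x d

print(indices.shape, z_q.shape)         # torch.Size([5]) torch.Size([5, 4])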
3. Loss Functions
The total VQ-GAN loss is typically composed of the following parts:
3.1 Reconstruction Loss (Pixel / Perceptual)
Pixel loss (optional):
$$L_{rec}^{pix} = \|x - \hat{x}\|_1 \ \text{or} \ \|x - \hat{x}\|_2^2$$
Perceptual loss (recommended):
$$L_{perc} = \sum_l \lambda_l \|\phi_l(x) - \phi_l(\hat{x})\|_2^2$$
where $\phi_l$ denotes the features of the $l$-th VGG layer and $\lambda_l$ is the layer weight; the perceptual loss preserves semantics and structure.
3.2 Vector-Quantization Losses (codebook / commitment)
VQ-VAE typically includes two terms (using stop-gradient):
$$L_{vq} = \underbrace{\|sg[z_e] - e\|_2^2}_{codebook\ loss} + \underbrace{\beta \|z_e - sg[e]\|_2^2}_{commitment\ loss}$$
where $sg[\cdot]$ denotes stop-gradient (the tensor is treated as a constant and excluded from backpropagation). The codebook loss updates the codebook vectors, while the commitment loss pushes the encoder outputs closer to the codebook vectors.
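In PyTorch, $sg[\cdot]$ corresponds to .detach(). The following sketch, using placeholder tensors for the encoder outputs and the matched codebook vectors, shows how the two terms map to code; note that the EMA-based quantizer implemented later in this article replaces the explicit codebook loss with moving-average updates and therefore only keeps the commitment term:
import torch
import torch.nn.functional as F

beta = 0.25
z_e = torch.randn(16, 256, requires_grad=True)  # placeholder encoder outputs
e = torch.randn(16, 256, requires_grad=True)    # placeholder matched codebook vectors

# ||sg[z_e] - e||^2 : gradients only reach the codebook vectors
codebook_loss = F.mse_loss(e, z_e.detach())
# beta * ||z_e - sg[e]||^2 : gradients only reach the encoder
commitment_loss = beta * F.mse_loss(z_e, e.detach())

loss_vq = codebook_loss + commitment_loss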
3.3 Adversarial Loss
Discriminator loss $L_D$ (commonly a hinge loss):
$$L_D = \mathbb{E}_x[\max(0, 1 - D(x))] + \mathbb{E}_{\hat{x}}[\max(0, 1 + D(\hat{x}))]$$
Generator adversarial loss (pushes $D$ to judge the output as real):
$$L_{adv} = -\mathbb{E}_{\hat{x}}[D(\hat{x})]$$
Using a GAN helps recover texture detail and realism.
3.4 Total Loss
The total generator loss is:
$$L_G = \lambda_{rec} L_{rec}^{pix} + \lambda_{perc} L_{perc} + \lambda_{adv} L_{adv} + L_{vq}$$
where the $\lambda$ coefficients control the weight of each term; $\lambda_{adv}$ is usually small (e.g. 0.1 or less), while $\lambda_{perc}$ is set higher to preserve image structure.
4. Implementing VQ-GAN
Next, we train a VQ-GAN on the huggan/selfie2anime dataset from Hugging Face (using only the anime-face portion, column imageB).
4.1 Loading the Dataset
(1) Import the required libraries, and define the hyperparameters and device:
import os
import math
import time
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils
from torchvision.models import vgg16
from torchvision import models
from PIL import Image
from tqdm import tqdm
import io
from datasets import load_dataset
# ---------- Hyperparameters ----------
IMG_SIZE = 128 # training resolution; adjust to fit GPU memory
BATCH_SIZE = 16
LR = 2e-4
EPOCHS = 100
CODEBOOK_SIZE = 1024 # number of codebook entries
CODEBOOK_DIM = 256 # dimension of each codebook vector
EMA_DECAY = 0.99 # decay used when updating the codebook with EMA
LATENT_CHANNELS = 256 # encoder output channels (before quantization)
SAVE_DIR = "./checkpoints_vqgan"
os.makedirs(SAVE_DIR, exist_ok=True) # samples and checkpoints are written here
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
(2) Wrap the Hugging Face dataset as a PyTorch Dataset that returns preprocessed images:
class HFImageDataset(Dataset):
def __init__(self, hf_split, image_size=IMG_SIZE):
self.ds = hf_split
self.image_size = image_size
self.tf = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(image_size),
transforms.ToTensor(), # [0,1]
transforms.Normalize([0.5]*3, [0.5]*3) # [-1,1]
])
self.col = 'imageB'
def __len__(self):
return len(self.ds)
def __getitem__(self, idx):
item = self.ds[idx][self.col]
# item may already be PIL.Image or dict with 'path' or 'bytes'
if isinstance(item, dict) and 'path' in item:
img = Image.open(item['path']).convert('RGB')
elif isinstance(item, dict) and 'bytes' in item:
img = Image.open(io.BytesIO(item['bytes'])).convert('RGB')
elif isinstance(item, Image.Image):
img = item.convert('RGB')
else:
# fallback: try to open if it's a string path
try:
img = Image.open(item).convert('RGB')
except Exception:
# last resort: assume HF returned PIL-compatible object
img = Image.fromarray(item).convert('RGB')
return self.tf(img)
hf_ds = load_dataset("huggan/selfie2anime", split="train")
dataset = HFImageDataset(hf_ds, image_size=IMG_SIZE)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
print("Dataset size:", len(dataset))
4.2 Building and Training the Model
(1) Define the vector-quantization module, which returns the quantized output and the loss term:
class VectorQuantizerEMA(nn.Module):
def __init__(self, num_tokens, dim, decay=EMA_DECAY, eps=1e-5):
super().__init__()
self.num_tokens = num_tokens
self.dim = dim
self.decay = decay
self.eps = eps
# codebook
embed = torch.randn(num_tokens, dim)
self.register_buffer('embedding', embed)
self.register_buffer('cluster_size', torch.zeros(num_tokens))
self.register_buffer('embed_avg', embed.clone())
def forward(self, z):
b, c, h, w = z.shape
flat = z.permute(0, 2, 3, 1).contiguous().view(-1, c) # N x C
# compute distances to embeddings
embed = self.embedding # K x C
# dist = ||x - e||^2 = x^2 + e^2 - 2 x e
distances = (flat.pow(2).sum(1, keepdim=True)
+ embed.pow(2).sum(1)
- 2 * flat @ embed.t()) # N x K
encoding_indices = torch.argmin(distances, dim=1) # N
encodings = F.one_hot(encoding_indices, self.num_tokens).type(flat.dtype) # N x K
quantized = (encodings @ embed).view(b, h, w, c).permute(0, 3, 1, 2).contiguous()
if self.training:
# update cluster_size and embed_avg
n_enc = encodings.sum(0) # K
embed_sum = encodings.t() @ flat # K x C
# EMA updates
self.cluster_size.data.mul_(self.decay).add_(n_enc, alpha=1 - self.decay)
self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
# normalize to get new embeddings
n = self.cluster_size.sum()
cluster_size = ((self.cluster_size + self.eps) / (n + self.num_tokens * self.eps)) * n
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
self.embedding.data.copy_(embed_normalized)
# commitment loss
commitment_loss = F.mse_loss(quantized.detach(), z)
# straight-through estimator
quantized = z + (quantized - z).detach()
return quantized, commitment_loss, encoding_indices.view(b, h, w)
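A quick shape check on a random feature map (sizes matching the defaults above) confirms that the quantizer preserves the input layout while producing an 8×8 grid of discrete indices and a scalar commitment loss:
vq = VectorQuantizerEMA(CODEBOOK_SIZE, CODEBOOK_DIM)
z = torch.randn(2, CODEBOOK_DIM, 8, 8)
z_q, commit_loss, idx = vq(z)
print(z_q.shape, idx.shape, commit_loss.item())
# torch.Size([2, 256, 8, 8]) torch.Size([2, 8, 8]) <scalar commitment loss>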
(2) Define the encoder and decoder:
def conv_block(in_ch, out_ch, kernel=3, stride=1, padding=1):
return nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel, stride=stride, padding=padding),
nn.BatchNorm2d(out_ch),
nn.GELU()
)
class Encoder(nn.Module):
def __init__(self, in_ch=3, base_ch=64, latent_ch=LATENT_CHANNELS):
super().__init__()
self.conv1 = conv_block(in_ch, base_ch, 4, 2, 1) # 128 -> 64
self.conv2 = conv_block(base_ch, base_ch*2, 4, 2, 1) # 64 -> 32
self.conv3 = conv_block(base_ch*2, base_ch*4, 4, 2, 1) # 32 -> 16
self.conv4 = conv_block(base_ch*4, latent_ch, 4, 2, 1) # 16 -> 8 (downsample factor 16)
# optionally more resblocks / attention can be added
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
z = self.conv4(x)
return z
class Decoder(nn.Module):
def __init__(self, out_ch=3, base_ch=64, latent_ch=LATENT_CHANNELS):
super().__init__()
self.deconv1 = nn.Sequential(
nn.ConvTranspose2d(latent_ch, base_ch*4, 4, 2, 1),
nn.BatchNorm2d(base_ch*4),
nn.GELU()
) # 8 -> 16
self.deconv2 = nn.Sequential(
nn.ConvTranspose2d(base_ch*4, base_ch*2, 4, 2, 1),
nn.BatchNorm2d(base_ch*2),
nn.GELU()
) # 16 -> 32
self.deconv3 = nn.Sequential(
nn.ConvTranspose2d(base_ch*2, base_ch, 4, 2, 1),
nn.BatchNorm2d(base_ch),
nn.GELU()
) # 32 -> 64
self.deconv4 = nn.Sequential(
nn.ConvTranspose2d(base_ch, out_ch, 4, 2, 1),
nn.Tanh()
) # 64 -> 128
def forward(self, z):
x = self.deconv1(z)
x = self.deconv2(x)
x = self.deconv3(x)
x = self.deconv4(x)
return x
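A minimal round-trip check confirms that the encoder downsamples a 128×128 input by a factor of 16 to an 8×8 latent grid, and that the decoder maps it back to the input resolution:
enc, dec = Encoder(), Decoder()
x = torch.randn(2, 3, IMG_SIZE, IMG_SIZE)
z = enc(x)
print(z.shape)        # torch.Size([2, 256, 8, 8])
print(dec(z).shape)   # torch.Size([2, 3, 128, 128])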
(3) Implement the PatchGAN discriminator used for adversarial training:
class PatchDiscriminator(nn.Module):
def __init__(self, in_ch=3, base=64):
super().__init__()
layers = [
nn.Conv2d(in_ch, base, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(base, base*2, 4, 2, 1), nn.BatchNorm2d(base*2), nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(base*2, base*4, 4, 2, 1), nn.BatchNorm2d(base*4), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(base*4, 1, 4, 1, 1) # output a patch map of real/fake logits
]
self.net = nn.Sequential(*layers)
def forward(self, x):
return self.net(x)
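Unlike a standard discriminator, the PatchGAN outputs a grid of logits, one per receptive-field patch; with the layer configuration above and 128×128 inputs the map is 15×15:
disc = PatchDiscriminator()
logits = disc(torch.randn(2, 3, IMG_SIZE, IMG_SIZE))
print(logits.shape)   # torch.Size([2, 1, 15, 15]) -- one real/fake logit per patch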
(4) Compute the perceptual loss from the intermediate-layer features of a pretrained VGG16:
class VGGPerceptual(nn.Module):
def __init__(self, resize=True):
super().__init__()
vgg = vgg16(weights=models.VGG16_Weights.DEFAULT).features.eval()
self.slice1 = nn.Sequential(*[vgg[x] for x in range(4)]) # relu1_2
self.slice2 = nn.Sequential(*[vgg[x] for x in range(4, 9)]) # relu2_2
# freeze
for p in self.parameters():
p.requires_grad = False
self.resize = resize
    def forward(self, x, y):
        # inputs are in [-1, 1]; rescale to [0, 1] before extracting VGG features
        x = (x + 1) * 0.5
        y = (y + 1) * 0.5
f1x = self.slice1(x)
f1y = self.slice1(y)
f2x = self.slice2(f1x)
f2y = self.slice2(f1y)
loss = F.mse_loss(f1x, f1y) + F.mse_loss(f2x, f2y)
return loss
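Note that this simplified module only rescales its inputs from [-1, 1] to [0, 1]; applying the ImageNet mean/std normalization that VGG16 was trained with is a common refinement, omitted here for brevity. Usage is a plain forward call on two image batches:
perc = VGGPerceptual().to(device)
a = torch.randn(2, 3, IMG_SIZE, IMG_SIZE, device=device)
b = torch.randn(2, 3, IMG_SIZE, IMG_SIZE, device=device)
print(perc(a, b).item())   # scalar perceptual distance between the two batches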
(5) Wrap the encoder, vector quantizer, and decoder into a single module:
class VQGAN(nn.Module):
def __init__(self, in_ch=3, base_ch=64, latent_ch=LATENT_CHANNELS,
codebook_size=CODEBOOK_SIZE, codebook_dim=CODEBOOK_DIM):
super().__init__()
self.encoder = Encoder(in_ch, base_ch, latent_ch)
# project encoder output to codebook dim if latent_ch != codebook_dim
self.enc_proj = nn.Conv2d(latent_ch, codebook_dim, 1)
self.quantize = VectorQuantizerEMA(codebook_size, codebook_dim)
self.dec_proj = nn.Conv2d(codebook_dim, latent_ch, 1)
self.decoder = Decoder(in_ch, base_ch, latent_ch)
def forward(self, x):
z_e = self.encoder(x) # B, latent_ch, H', W'
z_e_proj = self.enc_proj(z_e) # B, codebook_dim, H', W'
z_q, commit_loss, indices = self.quantize(z_e_proj)
z_q = self.dec_proj(z_q)
x_rec = self.decoder(z_q)
return x_rec, commit_loss, indices
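A forward pass through the assembled model returns the reconstruction, the commitment loss, and the 8×8 grid of codebook indices that later serves as the discrete token sequence:
vqgan = VQGAN().to(device)
x = torch.randn(2, 3, IMG_SIZE, IMG_SIZE, device=device)
x_rec, commit_loss, indices = vqgan(x)
print(x_rec.shape)     # torch.Size([2, 3, 128, 128])
print(indices.shape)   # torch.Size([2, 8, 8]) -- 64 discrete tokens per image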
(6) Implement the training loop and run training:
def train():
model = VQGAN().to(device)
disc = PatchDiscriminator().to(device)
vgg = VGGPerceptual().to(device)
opt_gen = torch.optim.Adam(list(model.parameters()), lr=LR, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(disc.parameters(), lr=LR, betas=(0.5, 0.9))
l1_loss = nn.L1Loss()
global_step = 0
for epoch in range(EPOCHS):
pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}")
for imgs in pbar:
imgs = imgs.to(device) # B x 3 x H x W
# Train Discriminator
model.eval()
with torch.no_grad():
recons, _, _ = model(imgs)
real_logits = disc(imgs)
fake_logits = disc(recons.detach())
d_loss = (F.relu(1.0 - real_logits).mean() + F.relu(1.0 + fake_logits).mean())
opt_disc.zero_grad()
d_loss.backward()
opt_disc.step()
# Train Generator (VQGAN)
model.train()
recons, commit_loss, _ = model(imgs)
# reconstruction L1
rec_loss = l1_loss(recons, imgs)
# perceptual
p_loss = vgg(recons, imgs)
# adversarial (generator wants discriminator to output high)
adv_logits = disc(recons)
g_adv_loss = -adv_logits.mean()
lambda_rec = 1.0
lambda_perc = 1.0
lambda_gan = 0.1
lambda_commit = 0.25
g_loss = (lambda_rec * rec_loss +
lambda_perc * p_loss +
lambda_gan * g_adv_loss +
lambda_commit * commit_loss)
opt_gen.zero_grad()
g_loss.backward()
opt_gen.step()
global_step += 1
if global_step % 100 == 0:
# save sample grid
with torch.no_grad():
samples = torch.cat([imgs[:8], recons[:8]], dim=0)
samples = (samples + 1) * 0.5 # to [0,1]
grid = utils.make_grid(samples, nrow=8)
utils.save_image(grid, os.path.join(SAVE_DIR, f"sample_{global_step}.png"))
pbar.set_postfix({
"d_loss": d_loss.item(),
"g_loss": g_loss.item(),
"rec": rec_loss.item(),
"perc": p_loss.item(),
"commit": commit_loss.item()
})
# end epoch checkpoint
torch.save({
"epoch": epoch,
"model_state": model.state_dict(),
"disc_state": disc.state_dict(),
"opt_gen": opt_gen.state_dict(),
"opt_disc": opt_disc.state_dict()
}, os.path.join(SAVE_DIR, f"vqgan_epoch_{epoch+1}.pt"))
if __name__ == "__main__":
train()
As training progresses, the reconstruction quality of the images gradually improves. This article focuses on training the VQ-GAN itself; once training is complete, a Transformer can be trained on the quantized tokens to generate images.
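As a starting point for that next step, here is a minimal sketch (assuming the final checkpoint vqgan_epoch_100.pt saved by the loop above) of how one might extract the flattened token sequences that a Transformer would be trained on; the row-major flattening of the 8×8 index grid is an illustrative choice, not something prescribed by VQ-GAN itself:
@torch.no_grad()
def extract_tokens(model, loader, device):
    """Encode images into flattened codebook-index sequences for a downstream Transformer."""
    model.eval()
    all_tokens = []
    for imgs in loader:
        imgs = imgs.to(device)
        _, _, indices = model(imgs)                   # B x 8 x 8 codebook indices
        all_tokens.append(indices.flatten(1).cpu())   # B x 64 token sequences (row-major)
    return torch.cat(all_tokens, dim=0)

model = VQGAN().to(device)
ckpt = torch.load(os.path.join(SAVE_DIR, "vqgan_epoch_100.pt"), map_location=device)
model.load_state_dict(ckpt["model_state"])
tokens = extract_tokens(model, dataloader, device)
print(tokens.shape)   # (dataset_size, 64)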