第 11 章 生成模型与 GAN¶
📚 章节概述¶
本章介绍生成模型的核心技术,包括 GAN 、 VAE 、扩散模型等。生成模型是计算机视觉的前沿方向,广泛应用于图像生成、风格迁移、数据增强等领域。
学习时间: 5-7 天 难度等级:⭐⭐⭐⭐⭐ 前置知识:第 5-6 章
🎯 学习目标¶
完成本章后,你将能够: - 理解生成模型的基本原理 - 掌握 GAN 的训练技巧 - 了解 VAE 和扩散模型 - 能够实现图像生成应用 - 完成图像生成项目
11.1 GAN 基础¶
11.1.1 GAN 原理¶
核心思想: - 生成器( Generator ):生成假样本 - 判别器( Discriminator ):区分真假 - 对抗训练:零和博弈
损失函数(极小极大目标):\(\min_G \max_D \; V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))]\)
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class Generator(nn.Module):
    """MLP generator: maps a latent vector z to an image.

    Interface matches the classic GAN tutorial model:
    ``Generator(latent_dim, img_shape)``; ``forward(z)`` returns a tensor of
    shape ``(batch, *img_shape)`` squashed to [-1, 1] by the final Tanh.
    """

    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        def fc_stage(n_in, n_out, normalize=True):
            # Linear -> (optional BatchNorm1d) -> LeakyReLU building block.
            stage = [nn.Linear(n_in, n_out)]
            if normalize:
                stage.append(nn.BatchNorm1d(n_out, momentum=0.8))
            stage.append(nn.LeakyReLU(0.2, inplace=True))
            return stage

        widths = [(latent_dim, 128), (128, 256), (256, 512), (512, 1024)]
        layers = []
        for idx, (n_in, n_out) in enumerate(widths):
            # First stage skips BatchNorm, exactly as in the original recipe.
            layers.extend(fc_stage(n_in, n_out, normalize=(idx != 0)))
        layers.append(nn.Linear(1024, int(np.prod(img_shape))))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        flat = self.model(z)
        # Reshape the flat pixel vector back into image form.
        return flat.view(flat.size(0), *self.img_shape)
class Discriminator(nn.Module):
    """MLP discriminator: flattens an image and outputs P(real) in (0, 1)."""

    def __init__(self, img_shape):
        super().__init__()
        n_pixels = int(np.prod(img_shape))
        self.model = nn.Sequential(
            nn.Linear(n_pixels, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # probability that the input image is real
        )

    def forward(self, img):
        # Flatten (batch, *img_shape) -> (batch, n_pixels) before the MLP.
        flattened = img.view(img.size(0), -1)
        return self.model(flattened)
11.1.2 DCGAN¶
class DCGANGenerator(nn.Module):
    """DCGAN generator: transposed convolutions upsample a (latent_dim, 1, 1)
    latent tensor to a 3 x 64 x 64 RGB image in [-1, 1]."""

    def __init__(self, latent_dim=100):
        super().__init__()
        # (in_ch, out_ch, stride, padding) per upsampling stage:
        # latent x1x1 -> 512x4x4 -> 256x8x8 -> 128x16x16 -> 64x32x32
        stages = [
            (latent_dim, 512, 1, 0),
            (512, 256, 2, 1),
            (256, 128, 2, 1),
            (128, 64, 2, 1),
        ]
        layers = []
        for c_in, c_out, stride, pad in stages:
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        # Final upsample to 3 x 64 x 64, squashed to [-1, 1].
        layers += [nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        return self.model(z)
11.2 VAE (Variational Autoencoder)¶
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3 x 32 x 32 inputs.

    ``encode`` yields the Gaussian posterior parameters (mu, logvar),
    ``reparameterize`` draws a differentiable latent sample, ``decode``
    reconstructs pixels in [0, 1] (Sigmoid). ``forward`` returns
    ``(reconstruction, mu, logvar)``.
    """

    def __init__(self, latent_dim=20):
        super().__init__()
        # Encoder: two stride-2 convs (32x32 -> 8x8), then an FC projection.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, 256),
            nn.ReLU(),
        )
        # Heads for the posterior mean and log-variance.
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)
        # Decoder mirrors the encoder: FC -> 8x8 feature map -> two stride-2 deconvs.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 64 * 8 * 8),
            nn.ReLU(),
            nn.Unflatten(1, (64, 8, 8)),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),
            nn.Sigmoid(),
        )

    def encode(self, x):
        features = self.encoder(x)
        return self.fc_mu(features), self.fc_logvar(features)

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps keeps sampling differentiable w.r.t. mu/logvar.
        sigma = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(sigma) * sigma

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        return self.decode(self.reparameterize(mu, logvar)), mu, logvar
# VAE loss function
def vae_loss(recon_x, x, mu, logvar):
    """Sum of the Bernoulli reconstruction BCE and the KL divergence to N(0, I).

    Assumes pixel values are normalized to [0, 1] and models the reconstruction
    with a Bernoulli likelihood; for continuous pixels an MSE / Gaussian
    likelihood is the usual alternative.
    """
    recon_term = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over batch and latent dims.
    kl_term = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - logvar - 1)
    return recon_term + kl_term
11.3 扩散模型¶
class SimpleDenoiser(nn.Module):
    """Minimal teaching-only denoiser, standing in for an undefined UNet."""

    def __init__(self, in_channels=3):
        super().__init__()
        hidden = 64
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, hidden, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, in_channels, 3, padding=1),
        )

    def forward(self, x, t):
        # t is accepted for API compatibility with real diffusion denoisers
        # but deliberately unused here.
        return self.net(x)
class DiffusionModel(nn.Module):
    """Minimal DDPM: linear beta schedule, forward noising (q_sample), one
    reverse denoising step (p_sample), and a noise-prediction MSE training
    objective (forward).

    Improvement over the original: the noise-prediction network is injectable
    via ``denoiser`` (backward compatible — the default still builds a
    ``SimpleDenoiser``), so real UNets can be plugged in without editing the
    class.
    """

    def __init__(self, timesteps=1000, denoiser=None):
        super().__init__()
        self.timesteps = timesteps
        # Linear variance schedule beta_t and its derived quantities; buffers
        # move with .to(device) and are saved in the state_dict.
        beta = torch.linspace(0.0001, 0.02, timesteps)
        self.register_buffer('beta', beta)
        self.register_buffer('alpha', 1 - beta)
        self.register_buffer('alpha_hat', torch.cumprod(1 - beta, dim=0))
        # Any callable (x, t) -> noise prediction works; default keeps the
        # original behavior.
        self.model = SimpleDenoiser() if denoiser is None else denoiser

    def q_sample(self, x_start, t, noise=None):
        """Forward process: x_t = sqrt(alpha_hat_t) * x_0 + sqrt(1 - alpha_hat_t) * eps.

        ``t`` is a per-sample batch of integer timesteps; ``noise`` defaults to
        a fresh standard Gaussian draw.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        a_hat = self.alpha_hat[t].view(-1, 1, 1, 1)
        return torch.sqrt(a_hat) * x_start + torch.sqrt(1 - a_hat) * noise

    def p_sample(self, x, t):
        """One reverse step from x_t toward x_{t-1} (DDPM sampling rule)."""
        predicted_noise = self.model(x, t)
        alpha = self.alpha[t].view(-1, 1, 1, 1)
        alpha_hat = self.alpha_hat[t].view(-1, 1, 1, 1)
        beta = self.beta[t].view(-1, 1, 1, 1)
        # Posterior mean: (x - (1-alpha)/sqrt(1-alpha_hat) * eps_pred) / sqrt(alpha).
        mean = (x - (1 - alpha) / torch.sqrt(1 - alpha_hat) * predicted_noise) / torch.sqrt(alpha)
        # Inject noise except at t == 0, where the final step is deterministic.
        nonzero_mask = (t > 0).float().view(-1, 1, 1, 1)
        return mean + nonzero_mask * torch.sqrt(beta) * torch.randn_like(x)

    def forward(self, x):
        """Training loss: MSE between the injected noise and its prediction at
        a uniformly random timestep per sample."""
        t = torch.randint(0, self.timesteps, (x.size(0),), device=x.device)
        noise = torch.randn_like(x)
        x_noisy = self.q_sample(x, t, noise)
        return F.mse_loss(self.model(x_noisy, t), noise)
11.4 实战案例:图像生成¶
# --- Training setup for the GAN image-generation demo (Section 11.4) ---
import torch.optim as optim
from torchvision import datasets, transforms

# Resize/crop CIFAR-10 to 64x64 and normalize each channel to [-1, 1],
# matching the Tanh output range of the generator.
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
# NOTE: download=True fetches CIFAR-10 into ./data on first run.
dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
latent_dim = 100
generator = Generator(latent_dim, (3, 64, 64)).to(device)
discriminator = Discriminator((3, 64, 64)).to(device)

# DCGAN-recommended Adam hyperparameters (lr=2e-4, beta1=0.5).
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
adversarial_loss = nn.BCELoss()  # pairs with the Sigmoid discriminator output
def train_gan(dataloader, epochs=50):
    """Alternating GAN training: one discriminator update, then one generator
    update per batch.

    Relies on the module-level ``generator``, ``discriminator``, their Adam
    optimizers, ``adversarial_loss`` (BCE), ``device`` and ``latent_dim``.
    Raises ValueError if ``dataloader`` is empty.
    """
    if len(dataloader) == 0:
        raise ValueError("dataloader 为空,无法开始 GAN 训练")
    for epoch in range(epochs):
        for batch_idx, (real_imgs, _) in enumerate(dataloader):
            real_imgs = real_imgs.to(device)
            n = real_imgs.size(0)
            # Target labels: 1 for real images, 0 for generated ones.
            ones = torch.ones(n, 1, device=device)
            zeros = torch.zeros(n, 1, device=device)

            # --- Discriminator step (generator frozen via detach) ---
            optimizer_D.zero_grad()
            loss_real = adversarial_loss(discriminator(real_imgs), ones)
            z = torch.randn(n, latent_dim, device=device)
            gen_imgs = generator(z)
            loss_fake = adversarial_loss(discriminator(gen_imgs.detach()), zeros)
            d_loss = (loss_real + loss_fake) / 2
            d_loss.backward()
            optimizer_D.step()

            # --- Generator step: fool the (just-updated) discriminator ---
            optimizer_G.zero_grad()
            g_loss = adversarial_loss(discriminator(gen_imgs), ones)
            g_loss.backward()
            optimizer_G.step()

            if batch_idx % 100 == 0:
                print(f"[Epoch {epoch}/{epochs}] [Batch {batch_idx}/{len(dataloader)}] "
                      f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")

train_gan(dataloader, epochs=50)
11.5 练习题¶
基础题¶
- 简答题:
- GAN 的工作原理是什么?
GAN 由生成器( Generator )和判别器( Discriminator )组成,进行对抗博弈:生成器从随机噪声生成假样本试图欺骗判别器,判别器则尝试区分真假样本。两者交替训练,理想情况下会逼近一种动态平衡,使生成样本越来越接近真实分布。训练目标通常写为 \(\min_G \max_D \mathbb{E}[\log D(x)] + \mathbb{E}[\log(1-D(G(z)))]\)。
- VAE 和 GAN 有什么区别?
VAE通过编码器将输入映射为潜在分布(均值+方差),用重参数化技巧采样后由解码器重建,优化重建损失+KL 散度,生成结果较模糊但训练稳定、有显式概率密度。GAN通过对抗训练学习隐式分布,生成结果更清晰锐利,但训练不稳定、易出现模式崩溃,且无法直接进行概率推断。
进阶题¶
- 编程题:
- 实现一个简单的 GAN 。
- 使用 VAE 生成图像。
11.6 关键复盘¶
高频复盘题¶
Q1: GAN 的训练难点是什么?
参考答案: - 模式崩溃( Mode Collapse ) - 训练不稳定 - 判别器过强/过弱 - 梯度消失/爆炸 - 解决方案: - Wasserstein GAN - 梯度惩罚 - 谱归一化
Q2: 扩散模型的原理是什么?
参考答案: - 前向过程:逐步添加噪声 - 反向过程:逐步去噪 - 训练:预测添加的噪声 - 采样:从噪声生成图像 - 优势:训练稳定、生成质量高
11.7 本章小结¶
核心知识点¶
- GAN:生成器、判别器、对抗训练
- DCGAN:卷积 GAN
- VAE:变分自编码器
- 扩散模型:前向扩散、反向去噪
下一步¶
下一章:12-视觉 Transformer.md - 学习 ViT
恭喜完成第 11 章! 🎉
⚠️ 核验说明(2026-04-03):本页已完成逐段人工复核,并为 VAE/GAN 示例补充了损失函数适用前提与空数据加载器检查。若文中涉及外部模型、API、版本号、价格、部署依赖或第三方产品名称,请以官方文档、论文原文和实际运行环境为准。
最后更新日期: 2026-04-03