跳转至

第 11 章 生成模型与 GAN

📚 章节概述

本章介绍生成模型的核心技术,包括 GAN 、 VAE 、扩散模型等。生成模型是计算机视觉的前沿方向,广泛应用于图像生成、风格迁移、数据增强等领域。

学习时间: 5-7 天 难度等级:⭐⭐⭐⭐⭐ 前置知识:第 5-6 章

🎯 学习目标

完成本章后,你将能够: - 理解生成模型的基本原理 - 掌握 GAN 的训练技巧 - 了解 VAE 和扩散模型 - 能够实现图像生成应用 - 完成图像生成项目


11.1 GAN 基础

11.1.1 GAN 原理

核心思想: - 生成器( Generator ):生成假样本 - 判别器( Discriminator ):区分真假 - 对抗训练:零和博弈

损失函数

Text Only
L_D = -E[log(D(x))] - E[log(1 - D(G(z)))]
L_G = -E[log(D(G(z)))]
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class Generator(nn.Module):
    """Fully-connected GAN generator mapping latent vectors to images.

    The network widens through 128 -> 256 -> 512 -> 1024 hidden units and
    ends with Tanh, so output pixels lie in [-1, 1].
    """

    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        widths = [latent_dim, 128, 256, 512, 1024]
        stages = []
        for idx, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            stages.append(nn.Linear(n_in, n_out))
            if idx > 0:  # the first stage skips normalization
                stages.append(nn.BatchNorm1d(n_out, momentum=0.8))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
        # Final projection to flat pixel count, squashed into [-1, 1].
        stages.append(nn.Linear(widths[-1], int(np.prod(img_shape))))
        stages.append(nn.Tanh())
        self.model = nn.Sequential(*stages)

    def forward(self, z):
        """Map latent codes z of shape (B, latent_dim) to images (B, *img_shape)."""
        flat = self.model(z)
        return flat.view(flat.size(0), *self.img_shape)

class Discriminator(nn.Module):
    """MLP discriminator: flattens an image batch and scores realness in (0, 1)."""

    def __init__(self, img_shape):
        super().__init__()
        # Shrink 512 -> 256 hidden units, then a single sigmoid output.
        sizes = [int(np.prod(img_shape)), 512, 256]
        layers = []
        for n_in, n_out in zip(sizes[:-1], sizes[1:]):
            layers += [nn.Linear(n_in, n_out), nn.LeakyReLU(0.2, inplace=True)]
        layers += [nn.Linear(sizes[-1], 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return per-sample realness probabilities of shape (B, 1)."""
        flat = img.view(img.size(0), -1)
        return self.model(flat)

11.1.2 DCGAN

Python
class DCGANGenerator(nn.Module):
    """DCGAN generator: upsamples a (latent_dim, 1, 1) code to a 3x64x64 image."""

    def __init__(self, latent_dim=100):
        super().__init__()

        def up(in_ch, out_ch):
            # ConvTranspose(k=4, s=2, p=1) doubles spatial size; BN + ReLU follow.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]

        self.model = nn.Sequential(
            # latent_dim x 1 x 1 -> 512 x 4 x 4
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            *up(512, 256),   # -> 256 x 8 x 8
            *up(256, 128),   # -> 128 x 16 x 16
            *up(128, 64),    # -> 64 x 32 x 32
            # Final upsample to RGB; Tanh puts pixels in [-1, 1].
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
            # 3 x 64 x 64
        )

    def forward(self, z):
        """z: (B, latent_dim, 1, 1) -> generated images (B, 3, 64, 64)."""
        return self.model(z)

11.2 VAE (Variational Autoencoder)

Python
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3x32x32 inputs.

    encode() yields the Gaussian posterior parameters, reparameterize()
    draws a differentiable latent sample, and decode() reconstructs pixels
    in [0, 1] via a final Sigmoid.
    """

    def __init__(self, latent_dim=20):
        super().__init__()

        # Encoder: two stride-2 convs shrink 32x32 -> 8x8, then a dense layer.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, 256),
            nn.ReLU(),
        )

        # Posterior heads: mean and log-variance of q(z|x).
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

        # Decoder mirrors the encoder: dense expansion, then two stride-2 deconvs.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 64 * 8 * 8),
            nn.ReLU(),
            nn.Unflatten(1, (64, 8, 8)),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Map an image batch to posterior parameters (mu, logvar)."""
        features = self.encoder(x)
        return self.fc_mu(features), self.fc_logvar(features)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + sigma * eps, eps ~ N(0, I) (reparameterization trick)."""
        sigma = torch.exp(0.5 * logvar)
        return torch.randn_like(sigma) * sigma + mu

    def decode(self, z):
        """Reconstruct an image batch from latent codes z."""
        return self.decoder(z)

    def forward(self, x):
        """Return (reconstruction, mu, logvar) for a batch of images."""
        mu, logvar = self.encode(x)
        sample = self.reparameterize(mu, logvar)
        return self.decode(sample), mu, logvar

# 损失函数
def vae_loss(recon_x, x, mu, logvar, *, beta=1.0):
    """Compute the (beta-)VAE objective: reconstruction loss + beta * KL.

    Assumes pixels are normalized to [0, 1] and models reconstruction with a
    Bernoulli likelihood (binary cross-entropy, summed); for continuous pixel
    models an MSE / Gaussian likelihood is the usual substitute.

    Args:
        recon_x: decoder output in [0, 1], same shape as ``x``.
        x: target batch with values in [0, 1].
        mu: posterior means, shape (B, latent_dim).
        logvar: posterior log-variances, shape (B, latent_dim).
        beta: weight on the KL term; ``beta=1.0`` (the default) recovers the
            original vanilla-VAE behavior, larger values give a beta-VAE.

    Returns:
        Scalar tensor: summed BCE plus ``beta`` times the summed KL divergence.
    """
    bce = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # Analytic KL between N(mu, sigma^2) and N(0, I), summed over batch and dims.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + beta * kld

11.3 扩散模型

Python
class SimpleDenoiser(nn.Module):
    """Minimal stand-in denoiser for teaching purposes, replacing a full UNet."""

    def __init__(self, in_channels=3):
        super().__init__()
        hidden = 64
        # Three 3x3 convs with padding keep the spatial size unchanged.
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, hidden, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, in_channels, 3, padding=1),
        )

    def forward(self, x, t):
        """Predict noise for x; the timestep t is accepted but unused here."""
        return self.net(x)


class DiffusionModel(nn.Module):
    """Minimal DDPM: linear beta schedule, forward noising and one reverse step.

    forward() returns the noise-prediction MSE loss used for training.
    """

    def __init__(self, timesteps=1000):
        super().__init__()
        self.timesteps = timesteps
        # Linear variance schedule; registered as buffers so they follow
        # the module across .to(device) / .cuda() calls.
        betas = torch.linspace(0.0001, 0.02, timesteps)
        self.register_buffer('beta', betas)
        self.register_buffer('alpha', 1 - betas)
        self.register_buffer('alpha_hat', torch.cumprod(1 - betas, dim=0))
        self.model = SimpleDenoiser()

    def q_sample(self, x_start, t, noise=None):
        """Diffuse x_start to step t: sqrt(a_hat)*x0 + sqrt(1-a_hat)*eps."""
        if noise is None:
            noise = torch.randn_like(x_start)
        a_hat = self.alpha_hat[t].view(-1, 1, 1, 1)
        return torch.sqrt(a_hat) * x_start + torch.sqrt(1 - a_hat) * noise

    def p_sample(self, x, t):
        """Take one reverse (denoising) step from x_t toward x_{t-1}."""
        eps_pred = self.model(x, t)
        alpha = self.alpha[t].view(-1, 1, 1, 1)
        alpha_hat = self.alpha_hat[t].view(-1, 1, 1, 1)
        beta = self.beta[t].view(-1, 1, 1, 1)
        # Posterior mean of x_{t-1} given the predicted noise.
        mean = (x - (1 - alpha) / torch.sqrt(1 - alpha_hat) * eps_pred) / torch.sqrt(alpha)

        # Inject fresh noise except at t == 0, where the step is deterministic.
        mask = (t > 0).float().view(-1, 1, 1, 1)
        return mean + mask * torch.sqrt(beta) * torch.randn_like(x)

    def forward(self, x):
        """Sample random timesteps, noise x, and return the denoiser's MSE loss."""
        t = torch.randint(0, self.timesteps, (x.size(0),), device=x.device)
        eps = torch.randn_like(x)
        x_t = self.q_sample(x, t, eps)
        return F.mse_loss(self.model(x_t, t), eps)

11.4 实战案例:图像生成

Python
import torch.optim as optim
from torchvision import datasets, transforms

# Preprocessing: upscale CIFAR-10 (32x32) to 64x64 and normalize each channel
# to [-1, 1], matching the generator's Tanh output range.
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Downloads the dataset into ./data on first run.
dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
latent_dim = 100
generator = Generator(latent_dim, (3, 64, 64)).to(device)
discriminator = Discriminator((3, 64, 64)).to(device)

# Adam with beta1 = 0.5 — the optimizer setting recommended by the DCGAN paper.
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
adversarial_loss = nn.BCELoss()


def train_gan(dataloader, epochs=50):
    """Alternating GAN training loop.

    Uses the module-level ``generator``, ``discriminator``, their optimizers,
    ``adversarial_loss``, ``device`` and ``latent_dim``. Each batch performs
    one discriminator update followed by one generator update.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if len(dataloader) == 0:
        raise ValueError("dataloader 为空,无法开始 GAN 训练")
    for epoch in range(epochs):
        for i, (imgs, _) in enumerate(dataloader):
            imgs = imgs.to(device)
            batch_size = imgs.size(0)
            # Target labels: 1 for real samples, 0 for generated ones.
            valid = torch.ones(batch_size, 1, device=device)
            fake = torch.zeros(batch_size, 1, device=device)

            # --- Discriminator step: score real images and detached fakes ---
            optimizer_D.zero_grad()
            real_loss = adversarial_loss(discriminator(imgs), valid)

            z = torch.randn(batch_size, latent_dim, device=device)
            gen_imgs = generator(z)
            # detach() keeps generator gradients out of the discriminator update.
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            # --- Generator step: try to make the discriminator output "real" ---
            optimizer_G.zero_grad()
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)
            g_loss.backward()
            optimizer_G.step()

            if i % 100 == 0:
                print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] "
                      f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")

train_gan(dataloader, epochs=50)

11.5 练习题

基础题

  1. 简答题
  2. GAN 的工作原理是什么?

    GAN 由生成器( Generator )和判别器( Discriminator )组成,进行对抗博弈:生成器从随机噪声生成假样本试图欺骗判别器,判别器则尝试区分真假样本。两者交替训练,理想情况下会逼近一种动态平衡,使生成样本越来越接近真实分布。训练目标通常写为 \(\min_G \max_D \mathbb{E}[\log D(x)] + \mathbb{E}[\log(1-D(G(z)))]\)

  3. VAE 和 GAN 有什么区别?

    VAE通过编码器将输入映射为潜在分布(均值+方差),用重参数化技巧采样后由解码器重建,优化重建损失+KL 散度,生成结果较模糊但训练稳定、有显式概率密度。GAN通过对抗训练学习隐式分布,生成结果更清晰锐利,但训练不稳定、易出现模式崩溃,且无法直接进行概率推断。

进阶题

  1. 编程题
  2. 实现一个简单的 GAN 。
  3. 使用 VAE 生成图像。

11.6 关键复盘

高频复盘题

Q1: GAN 的训练难点是什么?

参考答案: - 模式崩溃( Mode Collapse ) - 训练不稳定 - 判别器过强/过弱 - 梯度消失/爆炸 - 解决方案: - Wasserstein GAN - 梯度惩罚 - 谱归一化

Q2: 扩散模型的原理是什么?

参考答案: - 前向过程:逐步添加噪声 - 反向过程:逐步去噪 - 训练:预测添加的噪声 - 采样:从噪声生成图像 - 优势:训练稳定、生成质量高


11.7 本章小结

核心知识点

  1. GAN:生成器、判别器、对抗训练
  2. DCGAN:卷积 GAN
  3. VAE:变分自编码器
  4. 扩散模型:前向扩散、反向去噪

下一步

下一章:第 12 章《视觉 Transformer》(12-视觉 Transformer.md)- 学习 ViT


恭喜完成第 11 章! 🎉

⚠️ 核验说明(2026-04-03):本页已完成逐段人工复核,并为 VAE/GAN 示例补充了损失函数适用前提与空数据加载器检查。若文中涉及外部模型、API、版本号、价格、部署依赖或第三方产品名称,请以官方文档、论文原文和实际运行环境为准。


最后更新日期: 2026-04-03