跳转至

03 - 变分推断基础

学习时间: 4 小时 重要性: ⭐⭐⭐⭐⭐ 理解扩散模型训练目标的核心数学工具


🎯 学习目标

完成本章后,你将能够: - 理解变分推断的核心思想和数学原理 - 掌握证据下界( ELBO )的完整推导 - 理解均值场近似和重参数化技巧 - 将变分推断应用于扩散模型


1. 变分推断概述

1.1 问题背景

在概率模型中,我们经常遇到后验分布计算困难的问题:

\[p(z | x) = \frac{p(x | z) p(z)}{p(x)}\]

其中 \(p(x) = \int p(x | z) p(z) dz\)边缘似然( evidence ),通常难以计算。

变分推断的核心思想:用一个简单的分布 \(q(z)\) 来近似复杂的后验 \(p(z | x)\)

1.2 为什么需要变分推断

场景 1 :贝叶斯神经网络 - 需要计算权重的后验分布 - 积分空间维度极高(数百万参数)

场景 2 :生成模型 - VAE 、扩散模型都需要近似后验 - 直接计算不可行

场景 3 :概率图模型 - 复杂图结构导致推断困难 - 需要近似算法

1.3 变分推断 vs MCMC

方法 优点 缺点 适用场景
变分推断 快、可扩展、确定性 有偏、近似误差 大规模数据、实时应用
MCMC 渐近精确 慢、难以收敛 小规模数据、精确推断

扩散模型选择变分推断,因为: 1. 需要高效训练 2. 可以接受近似 3. 与深度学习框架兼容


2. KL 散度与变分目标

2.1 KL 散度的定义

KL 散度衡量两个分布之间的差异:

\[D_{KL}(q(z) \| p(z | x)) = \int q(z) \log \frac{q(z)}{p(z | x)} dz\]

性质: - \(D_{KL} \geq 0\),当且仅当 \(q = p\) 时等于 0 - 不对称\(D_{KL}(q \Vert p) \neq D_{KL}(p \| q)\) - 变分推断通常最小化 \(D_{KL}(q \Vert p)\)(前向 KL )

2.2 推导证据下界( ELBO )

我们的目标是最小化 \(D_{KL}(q(z) \| p(z | x))\),但 \(p(z | x)\) 未知。展开 KL 散度:

\[ \begin{aligned} D_{KL}(q(z) \| p(z | x)) &= \mathbb{E}_{q(z)} \left[ \log q(z) - \log p(z | x) \right] \\ &= \mathbb{E}_{q(z)} \left[ \log q(z) - \log \frac{p(x, z)}{p(x)} \right] \\ &= \mathbb{E}_{q(z)} \left[ \log q(z) - \log p(x, z) + \log p(x) \right] \\ &= \mathbb{E}_{q(z)} \left[ \log q(z) - \log p(x, z) \right] + \log p(x) \end{aligned} \]

重新整理:

\[\log p(x) = D_{KL}(q(z) \| p(z | x)) + \mathbb{E}_{q(z)} \left[ \log p(x, z) - \log q(z) \right]\]

由于 \(D_{KL} \geq 0\),我们得到:

\[\log p(x) \geq \mathbb{E}_{q(z)} \left[ \log p(x, z) - \log q(z) \right] =: \mathcal{L}(q)\]

这就是证据下界( Evidence Lower BOund, ELBO )

2.3 ELBO 的等价形式

ELBO 有多种等价表达形式:

形式 1: $\(\mathcal{L}(q) = \mathbb{E}_{q(z)} [\log p(x | z)] - D_{KL}(q(z) \| p(z))\)$

  • 第一项:重构项,衡量解码质量
  • 第二项:正则项,使 \(q(z)\) 接近先验 \(p(z)\)

形式 2: $\(\mathcal{L}(q) = \log p(x) - D_{KL}(q(z) \| p(z | x))\)$

这表明: - 最大化 ELBO 等价于最小化 \(D_{KL}(q(z) \| p(z | x))\) - ELBO 越紧(接近 \(\log p(x)\)),近似越好

2.4 代码验证

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

def kl_divergence_gaussian(mu1, logvar1, mu2, logvar2):
    """
    计算两个高斯分布之间的KL散度

    参数:
        mu1, logvar1: 分布q的均值和对数方差
        mu2, logvar2: 分布p的均值和对数方差

    返回:
        KL(q||p)
    """
    var1 = torch.exp(logvar1)
    var2 = torch.exp(logvar2)

    kl = 0.5 * (logvar2 - logvar1 + (var1 + (mu1 - mu2)**2) / var2 - 1)
    return kl.sum(dim=-1)

def compute_elbo(recon_x, x, mu, logvar):
    """
    计算VAE的ELBO

    参数:
        recon_x: 重构的图像
        x: 原始图像
        mu, logvar: 编码器输出的均值和对数方差

    返回:
        ELBO损失
    """
    # 重构损失(二元交叉熵或MSE)
    recon_loss = F.mse_loss(recon_x, x, reduction='sum') / x.shape[0]  # F.xxx PyTorch函数式API

    # KL散度
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.shape[0]

    # ELBO = - (重构损失 + KL散度)
    elbo = -(recon_loss + kl_loss)

    return elbo, recon_loss, kl_loss

# 测试
print("=" * 60)
print("ELBO计算测试")
print("=" * 60)

batch_size = 4
latent_dim = 10
image_size = 32

# 模拟数据
x = torch.randn(batch_size, 3, image_size, image_size)
recon_x = torch.randn(batch_size, 3, image_size, image_size)
mu = torch.randn(batch_size, latent_dim)
logvar = torch.randn(batch_size, latent_dim)

elbo, recon_loss, kl_loss = compute_elbo(recon_x, x, mu, logvar)

print(f"重构损失: {recon_loss.item():.4f}")  # 将单元素张量转为Python数值
print(f"KL散度: {kl_loss.item():.4f}")
print(f"ELBO: {elbo.item():.4f}")

3. 均值场近似

3.1 什么是均值场近似

均值场近似( Mean Field Approximation )假设变分分布可以分解为独立因子的乘积:

\[q(z) = \prod_{i=1}^M q_i(z_i)\]

这意味着我们假设隐变量的不同维度之间相互独立。

3.2 坐标上升变分推断( CAVI )

目标:找到最优的 \(q_i(z_i)\) 来最大化 ELBO 。

推导

对于每个因子 \(q_j(z_j)\),固定其他因子,最优解为:

\[\log q_j^*(z_j) = \mathbb{E}_{-j} [\log p(x, z)] + \text{const}\]

其中 \(\mathbb{E}_{-j}\) 表示对除 \(z_j\) 外的所有变量求期望。

算法流程

Text Only
算法: CAVI
─────────────────────────────────
初始化所有 q_i(z_i)

重复直到收敛:
  对于每个 j = 1, ..., M:
    计算: log q_j(z_j) = E_{-j}[log p(x, z)] + const
    归一化得到 q_j(z_j)

  计算 ELBO

3.3 高斯均值场

假设每个 \(q_i(z_i)\) 是高斯分布:

\[q_i(z_i) = \mathcal{N}(z_i; \mu_i, \sigma_i^2)\]

优化参数:均值 \(\mu_i\) 和方差 \(\sigma_i^2\)(或对数方差)。

梯度

\[\frac{\partial \mathcal{L}}{\partial \mu_i} = \mathbb{E}_{q} \left[ \frac{\partial \log p(x, z)}{\partial z_i} \right]\]
\[\frac{\partial \mathcal{L}}{\partial \sigma_i} = \mathbb{E}_{q} \left[ \frac{\partial \log p(x, z)}{\partial z_i} \cdot \frac{z_i - \mu_i}{\sigma_i} \right] + \frac{1}{\sigma_i}\]

4. 重参数化技巧( Reparameterization Trick )

4.1 问题:随机节点的梯度

在变分推断中,我们需要从 \(q_\phi(z)\) 采样,然后计算梯度:

\[\nabla_\phi \mathbb{E}_{q_\phi(z)} [f(z)]\]

问题:采样操作是不可导的!

4.2 解决方案

对于高斯分布 \(q_\phi(z) = \mathcal{N}(z; \mu_\phi(x), \sigma_\phi^2(x))\),我们可以:

\[z = \mu_\phi(x) + \sigma_\phi(x) \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)\]

这样:

\[\mathbb{E}_{q_\phi(z)} [f(z)] = \mathbb{E}_{\mathcal{N}(\epsilon; 0, I)} [f(\mu_\phi(x) + \sigma_\phi(x) \cdot \epsilon)]\]

梯度可以通过 \(\mu_\phi\)\(\sigma_\phi\) 传播!

4.3 一般形式

对于任意分布 \(q_\phi(z)\),如果存在变换:

\[z = g_\phi(\epsilon, x), \quad \epsilon \sim p(\epsilon)\]

则:

\[\mathbb{E}_{q_\phi(z)} [f(z)] = \mathbb{E}_{p(\epsilon)} [f(g_\phi(\epsilon, x))]\]

常见分布的重参数化

分布 重参数化
高斯 \(z = \mu + \sigma \cdot \epsilon, \epsilon \sim \mathcal{N}(0, I)\)
指数 \(z = -\log(1 - \epsilon) / \lambda, \epsilon \sim \text{Uniform}(0, 1)\)
Gumbel \(z = \mu - \beta \log(-\log \epsilon), \epsilon \sim \text{Uniform}(0, 1)\)

4.4 代码实现

Python
class ReparameterizedGaussian:
    """
    可重参数化的高斯分布
    """
    def __init__(self, mu, logvar):
        self.mu = mu
        self.logvar = logvar
        self.std = torch.exp(0.5 * logvar)

    def sample(self, num_samples=None):
        """
        重参数化采样
        """
        if num_samples is None:
            eps = torch.randn_like(self.std)
        else:
            eps = torch.randn(num_samples, *self.std.shape)
        return self.mu + self.std * eps

    def rsample(self, num_samples=None):
        """
        可微采样(PyTorch风格)
        """
        return self.sample(num_samples)

    def log_prob(self, z):
        """
        计算对数概率
        """
        var = torch.exp(self.logvar)
        log_prob = -0.5 * (
            torch.log(2 * torch.pi * var) +
            (z - self.mu).pow(2) / var
        )
        return log_prob.sum(dim=-1)

    def kl_divergence(self, other_mu=0, other_logvar=0):
        """
        计算与标准高斯的KL散度
        """
        return kl_divergence_gaussian(
            self.mu, self.logvar,
            torch.zeros_like(self.mu),
            torch.zeros_like(self.logvar)
        )

# 测试重参数化
print("\n" + "=" * 60)
print("重参数化技巧测试")
print("=" * 60)

mu = torch.tensor([0.0, 1.0, -1.0])
logvar = torch.tensor([0.0, 0.5, -0.5])

dist = ReparameterizedGaussian(mu, logvar)

# 多次采样
samples = []
for _ in range(1000):
    z = dist.sample()
    samples.append(z)

samples = torch.stack(samples)  # torch.stack沿新维度拼接张量

print(f"理论均值: {mu}")
print(f"采样均值: {samples.mean(dim=0)}")
print(f"理论标准差: {torch.exp(0.5 * logvar)}")
print(f"采样标准差: {samples.std(dim=0)}")

# 验证梯度
mu_param = torch.tensor([0.0, 1.0], requires_grad=True)
logvar_param = torch.tensor([0.0, 0.0], requires_grad=True)

dist_grad = ReparameterizedGaussian(mu_param, logvar_param)
z = dist_grad.sample()
loss = z.pow(2).sum()
loss.backward()  # 反向传播计算梯度

print(f"\n梯度验证:")
print(f"mu梯度: {mu_param.grad}")
print(f"logvar梯度: {logvar_param.grad}")

5. 变分自编码器( VAE )

5.1 VAE 作为变分推断

VAE 是变分推断在深度学习中的典型应用:

生成模型: $\(p_\theta(x, z) = p_\theta(x | z) p(z)\)$

推断模型(编码器): $\(q_\phi(z | x) = \mathcal{N}(z; \mu_\phi(x), \sigma_\phi^2(x))\)$

ELBO: $\(\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z | x)} [\log p_\theta(x | z)] - D_{KL}(q_\phi(z | x) \| p(z))\)$

5.2 与扩散模型的联系

特性 VAE 扩散模型
隐变量 低维 \(z\) 高维 \(x_{1:T}\)
编码器 学习 \(q_\phi(z \mid x)\) 固定前向过程
解码器 学习 \(p_\theta(x \mid z)\) 学习反向过程
推断 单步 多步马尔可夫链

扩散模型可以看作: - 编码器是固定的(前向扩散) - 解码器是多步的(反向去噪) - 隐变量与数据同维度


6. 变分推断在扩散模型中的应用

6.1 扩散模型的变分目标

回顾扩散模型的联合分布:

前向(固定): $\(q(x_{0:T}) = q(x_0) \prod_{t=1}^T q(x_t | x_{t-1})\)$

反向(学习): $\(p_\theta(x_{0:T}) = p(x_T) \prod_{t=1}^T p_\theta(x_{t-1} | x_t)\)$

变分下界: $\(\log p_\theta(x_0) \geq \mathbb{E}_{q(x_{1:T} | x_0)} \left[ \log \frac{p_\theta(x_{0:T})}{q(x_{1:T} | x_0)} \right]\)$

6.2 简化为去噪目标

经过推导(详见第 4 章), ELBO 可以简化为:

\[\mathcal{L} = \mathbb{E}_{t, x_0, \epsilon} \left[ \| \epsilon - \epsilon_\theta(x_t, t) \|^2 \right]\]

这就是去噪目标

直观理解: - 变分推断告诉我们应该优化什么( ELBO ) - 数学推导简化为具体的目标函数 - 最终形式简单但有效

6.3 为什么扩散模型有效

  1. 灵活的近似族:高斯分布可以近似复杂的后验
  2. 多步细化:每一步只需要做小的修正
  3. 重参数化:可以使用梯度下降高效训练
  4. 变分下界:提供了理论保证

7. 高级主题

7.1 重要性加权自编码器( IWAE )

使用多个样本获得更紧的下界:

\[\mathcal{L}_k = \mathbb{E}_{z_1, ..., z_k \sim q(z | x)} \left[ \log \frac{1}{k} \sum_{i=1}^k \frac{p(x, z_i)}{q(z_i | x)} \right]\]

\(k \to \infty\) 时,\(\mathcal{L}_k \to \log p(x)\)

7.2 归一化流( Normalizing Flows )

使用可逆变换构建复杂的变分分布:

\[z = f_\phi(\epsilon), \quad \epsilon \sim \mathcal{N}(0, I)\]
\[q_\phi(z) = p(\epsilon) \left| \det \frac{\partial f_\phi^{-1}}{\partial z} \right|\]

7.3 变分推断的局限性

  1. 近似误差\(q(z)\) 可能无法很好地近似 \(p(z | x)\)
  2. 局部最优:优化可能陷入局部最优
  3. 模型选择:选择合适的 \(q(z)\) 族很重要
  4. 计算成本:对于复杂模型,计算成本仍然很高

8. 本章总结

核心概念

  1. 变分推断
  2. 用简单分布近似复杂后验
  3. 最小化 KL 散度
  4. 最大化证据下界( ELBO )

  5. ELBO

  6. \(\mathcal{L} = \mathbb{E}_q[\log p(x, z) - \log q(z)]\)
  7. 重构项 + 正则项
  8. 提供了似然的下界

  9. 重参数化技巧

  10. 使随机采样可微
  11. 高斯:\(z = \mu + \sigma \cdot \epsilon\)
  12. 支持端到端训练

  13. 与扩散模型的联系

  14. 扩散模型是多步 VAE
  15. 前向过程是固定编码器
  16. 反向过程是学习解码器

关键公式

概念 公式
KL 散度 \(D_{KL}(q \Vert p) = \mathbb{E}_q[\log q - \log p]\)
ELBO \(\mathcal{L} = \mathbb{E}_q[\log p(x, z) - \log q(z)]\)
高斯 ELBO \(\mathcal{L} = \mathbb{E}[\log p(x \| z)] - D_{KL}(q(z) \| p(z))\)
重参数化 \(z = \mu + \sigma \cdot \epsilon\)

算法流程

Text Only
变分推断算法:
1. 选择变分分布族 q_φ(z)
2. 初始化参数 φ
3. 重复:
   a. 从 q_φ(z) 采样 z
   b. 计算 ELBO
   c. 计算梯度 ∇_φ ELBO
   d. 更新参数 φ
4. 直到收敛

📝 自测问题

基础问题

  1. 变分推断基础
  2. 为什么需要变分推断?
  3. KL 散度为什么不对称?
  4. 前向 KL 和后向 KL 有什么区别?

  5. ELBO 推导

  6. 从 KL 散度推导出 ELBO
  7. ELBO 的两项分别代表什么?
  8. 为什么 ELBO 是似然的下界?

  9. 重参数化技巧

  10. 为什么需要重参数化?
  11. 高斯分布如何重参数化?
  12. 重参数化对梯度计算有什么帮助?

编程练习

  1. 实现完整的 VAE 模型
  2. 比较不同样本数下的 ELBO
  3. 实现重要性加权自编码器( IWAE )
  4. 可视化变分分布的优化过程

思考题

  1. 变分推断的近似误差来自哪里?
  2. 如何设计更好的变分分布?
  3. 扩散模型相比 VAE 的优势在哪里?
  4. 变分推断和 MCMC 如何选择?

🔗 下一步

理解了变分推断后,我们将学习线性代数基础,这是理解扩散模型的数学工具。

→ 下一步:04-线性代数基础.md