跳转至

07 - 连续时间扩散模型

学习时间: 5 小时 重要性: ⭐⭐⭐⭐⭐ 从离散到连续,理解扩散模型的统一框架


🎯 学习目标

完成本章后,你将能够: - 理解从离散时间到连续时间的过渡 - 掌握 VP-SDE 、 VE-SDE 和子 VP-SDE 的区别 - 理解得分随机微分方程( Score SDE ) - 掌握概率流 ODE 和似然计算 - 实现连续时间扩散模型


1. 从离散到连续

1.1 离散时间扩散的局限

回顾 DDPM 的离散时间前向过程: $\(x_{t} = \sqrt{1-\beta_t} x_{t-1} + \sqrt{\beta_t} \epsilon_{t-1}\)$

问题: - 时间步数 \(T\) 是离散的(通常 1000 步) - 步长 \(\beta_t\) 需要精心设计 - 难以分析连续时间极限

1.2 连续时间极限

令时间步长 \(\Delta t \to 0\),离散过程收敛到随机微分方程( SDE )

\[dx = f(x, t) dt + g(t) dw\]

其中: - \(f(x, t)\):漂移系数 - \(g(t)\):扩散系数 - \(w\):维纳过程(标准布朗运动)

1.3 连续时间的优势

  1. 数学优雅:可以使用随机分析工具
  2. 统一框架:涵盖多种扩散模型
  3. 灵活采样:任意时间步长
  4. 精确似然:通过概率流 ODE 计算

2. 三种经典 SDE

2.1 VP-SDE ( Variance Preserving SDE )

动机:保持方差有界,避免爆炸或消失

SDE 形式: $\(dx = -\frac{1}{2}\beta(t) x dt + \sqrt{\beta(t)} dw\)$

扰动核: $\(p(x(t) | x(0)) = \mathcal{N}(x(t); e^{-\frac{1}{2}\int_0^t \beta(s)ds} x(0), (1 - e^{-\int_0^t \beta(s)ds})I)\)$

特性: - 当 \(t \to \infty\)\(x(t) \sim \mathcal{N}(0, I)\) - 方差保持在 \([0, 1]\) 范围内 - 对应于 DDPM 的连续版本

2.2 VE-SDE ( Variance Exploding SDE )

动机:允许方差随时间增长

SDE 形式: $\(dx = \sqrt{\frac{d[\sigma^2(t)]}{dt}} dw\)$

扰动核: $\(p(x(t) | x(0)) = \mathcal{N}(x(t); x(0), (\sigma^2(t) - \sigma^2(0))I)\)$

特性: - 方差随时间增长( exploding ) - 需要选择适当的 \(\sigma(t)\) 调度 - 对应于 SMLD ( Score Matching with Langevin Dynamics )

2.3 子 VP-SDE ( Sub-Variance Preserving SDE )

动机: VP-SDE 的改进版本,更好的数值稳定性

SDE 形式: $\(dx = -\frac{1}{2}\beta(t) x dt + \sqrt{\beta(t)(1 - e^{-2\int_0^t \beta(s)ds})} dw\)$

特性: - 介于 VP-SDE 和 VE-SDE 之间 - 在 Score SDE 论文中表现最好 - 更好的数值稳定性

2.4 三种 SDE 对比

特性 VP-SDE VE-SDE 子 VP-SDE
漂移
方差 保持 爆炸 次保持
稳定性 需调参 最好
对应离散模型 DDPM SMLD 改进版

3. 得分随机微分方程( Score SDE )

3.1 统一框架

一般形式的前向 SDE: $\(dx = f(x, t) dt + g(t) dw\)$

关键洞察:通过Anderson 定理,可以推导出对应的逆向 SDE

\[dx = [f(x, t) - g^2(t) \nabla_x \log p_t(x)] dt + g(t) d\bar{w}\]

其中: - \(\nabla_x \log p_t(x)\)得分函数( score function ) - \(d\bar{w}\) 是反向时间的维纳过程

3.2 得分函数估计

在实际中,我们用神经网络 \(s_\theta(x, t)\) 来近似得分函数:

\[s_\theta(x, t) \approx \nabla_x \log p_t(x)\]

训练目标(去噪得分匹配): $\(\mathcal{L}(\theta) = \mathbb{E}_{t, x(0), x(t)} \left[ \| s_\theta(x(t), t) - \nabla_{x(t)} \log p(x(t) | x(0)) \|^2 \right]\)$

为什么去噪得分匹配等价于显式得分匹配?

显式得分匹配( ESM )的目标是最小化 \(\mathbb{E}_{p_t(x)}[\|s_\theta(x,t) - \nabla_x \log p_t(x)\|^2]\),但 \(\nabla_x \log p_t(x)\) 未知。关键等式为:

\[\mathbb{E}_{p_t(x)}\left[\|s_\theta - \nabla_x \log p_t\|^2\right] = \mathbb{E}_{p_0(x_0)p_{t|0}(x_t|x_0)}\left[\|s_\theta(x_t,t) - \nabla_{x_t} \log p_{t|0}(x_t|x_0)\|^2\right] + C\]

其中 \(C\) 不依赖 \(\theta\)。证明思路:将 \(\nabla_x \log p_t(x)\) 展开为 \(\nabla_x \log \int p_{t|0}(x|x_0)p_0(x_0)dx_0\),利用 \(\nabla_x \log p_t(x) = \mathbb{E}_{p_0(x_0|x)}[\nabla_x \log p_{t|0}(x|x_0)]\)(后验加权平均),然后展开平方范数并利用该恒等式消去交叉项中的未知量。最终两个目标关于 \(\theta\) 的梯度相同( Vincent, 2011 )。

因此,我们可以用已知的条件得分 \(\nabla_{x_t} \log p(x_t|x_0)\) 替代未知的边缘得分进行训练。对于高斯扰动核 \(p(x_t|x_0) = \mathcal{N}(\alpha_t x_0, \sigma_t^2 I)\),条件得分为 \(-\frac{x_t - \alpha_t x_0}{\sigma_t^2}\)

3.3 与 DDPM 的联系

对于 VP-SDE ,得分函数与 DDPM 的噪声预测的关系:

\[s_\theta(x, t) = -\frac{\epsilon_\theta(x, t)}{\sqrt{1 - \alpha(t)}}\]

其中 \(\alpha(t) = e^{-\int_0^t \beta(s)ds}\)


4. 概率流 ODE

4.1 确定性生成

定理:存在一个常微分方程( ODE ),其边缘分布与 SDE 相同:

\[\frac{dx}{dt} = f(x, t) - \frac{1}{2}g^2(t) s_\theta(x, t)\]

这就是概率流 ODE ( Probability Flow ODE )

4.2 ODE vs SDE

特性 SDE ODE
随机性
可逆性 概率可逆 确定性可逆
似然计算 困难 容易(用流模型技术)
采样速度

4.3 精确似然计算

通过概率流 ODE ,可以使用瞬时变化率公式计算精确似然:

\[\log p(x(0)) = \log p(x(T)) + \int_0^T \nabla \cdot \tilde{f}(x(t), t) dt\]

其中 \(\tilde{f}(x, t) = f(x, t) - \frac{1}{2}g^2(t) s_\theta(x, t)\)

散度计算: - 精确方法:自动微分(计算量大) - 近似方法: Skilling-Hutchinson 迹估计


5. 数值方法

5.1 SDE 求解器

Euler-Maruyama 方法: $\(x_{t+\Delta t} = x_t + f(x_t, t)\Delta t + g(t) \sqrt{\Delta t} z_t, \quad z_t \sim \mathcal{N}(0, I)\)$

Milstein 方法(更高精度): 增加二阶修正项

预测-校正方法: 1. 预测步:用 Euler-Maruyama 得到 \(\tilde{x}_{t+\Delta t}\) 2. 校正步:用得分函数修正

5.2 ODE 求解器

Runge-Kutta 方法: - RK45 :自适应步长 - RK4 :固定步长

Adams-Bashforth 方法: 多步法,利用历史信息

5.3 步长选择

自适应步长: 根据局部误差估计调整步长

固定步长: 均匀离散化,简单易实现


6. 实现

Python
import torch
import torch.nn as nn
import numpy as np
from scipy import integrate

class ContinuousDiffusion:
    """
    连续时间扩散模型
    """
    def __init__(self, sde_type='vp', beta_min=0.1, beta_max=20.0):
        """
        参数:
            sde_type: 'vp', 've', 或 'subvp'
            beta_min, beta_max: beta调度参数
        """
        self.sde_type = sde_type
        self.beta_min = beta_min
        self.beta_max = beta_max

    def beta(self, t):
        """beta(t)调度"""
        return self.beta_min + t * (self.beta_max - self.beta_min)

    def drift(self, x, t):
        """漂移系数 f(x, t)"""
        if self.sde_type == 'vp':
            return -0.5 * self.beta(t) * x
        elif self.sde_type == 've':
            return torch.zeros_like(x)
        elif self.sde_type == 'subvp':
            return -0.5 * self.beta(t) * x
        else:
            raise ValueError(f"Unknown SDE type: {self.sde_type}")

    def diffusion(self, t):
        """扩散系数 g(t)"""
        if self.sde_type == 'vp':
            return np.sqrt(self.beta(t))
        elif self.sde_type == 've':
            sigma = self.sigma(t)
            # g(t) = σ(t)√(2 ln(σ_max/σ_min)),由 d[σ²(t)]/dt 推导
            return sigma * np.sqrt(2 * np.log(self.beta_max / self.beta_min))
        elif self.sde_type == 'subvp':
            alpha = np.exp(-0.5 * self.beta_min * t - 0.25 * (self.beta_max - self.beta_min) * t**2)
            # g(t) = √(β(t)(1 - e^{-2∫β})),其中 e^{-2∫β} = α⁴
            return np.sqrt(self.beta(t) * (1 - alpha**4))
        else:
            raise ValueError(f"Unknown SDE type: {self.sde_type}")

    def sigma(self, t):
        """VE-SDE的sigma(t)"""
        return self.beta_min * (self.beta_max / self.beta_min) ** t

    def marginal_prob(self, x0, t):
        """
        计算边缘分布 p(x(t) | x(0))

        返回:
            mean: 均值
            std: 标准差
        """
        if self.sde_type == 'vp':
            alpha = np.exp(-0.5 * self.beta_min * t - 0.25 * (self.beta_max - self.beta_min) * t**2)
            mean = alpha * x0
            std = np.sqrt(1 - alpha**2)
        elif self.sde_type == 've':
            mean = x0
            # 方差为 σ²(t) - σ²(0),标准差取平方根
            std = np.sqrt(self.sigma(t)**2 - self.sigma(0)**2)
        elif self.sde_type == 'subvp':
            alpha = np.exp(-0.5 * self.beta_min * t - 0.25 * (self.beta_max - self.beta_min) * t**2)
            mean = alpha * x0
            # sub-VP-SDE: 方差 = (1-α²)²,标准差 = |1-α²| = 1-α² (因为α∈[0,1])
            # 注意:这里std是标准差,直接等于(1-α²)
            std = 1 - alpha**2
        else:
            raise ValueError(f"Unknown SDE type: {self.sde_type}")

        return mean, std

    def prior_sampling(self, shape):
        """从先验分布采样"""
        if self.sde_type == 'vp' or self.sde_type == 'subvp':
            return torch.randn(*shape)
        elif self.sde_type == 've':
            return torch.randn(*shape) * self.sigma(1.0)

    def forward_sde(self, x0, t, noise=None):
        """
        前向SDE: x(t) = mean + std * noise
        """
        mean, std = self.marginal_prob(x0, t)
        if noise is None:
            noise = torch.randn_like(x0)
        xt = mean + std * noise
        return xt, noise

    def reverse_sde(self, score_model, xt, t, dt, noise=None):
        """
        逆向SDE采样一步

        参数:
            score_model: 得分模型 s_θ(x, t)
            xt: 当前状态
            t: 当前时间
            dt: 时间步长(负值表示反向)
            noise: 可选的噪声
        """
        if noise is None:
            noise = torch.randn_like(xt)

        # 计算得分
        score = score_model(xt, t)

        # 漂移和扩散
        f = self.drift(xt, t)
        g = self.diffusion(t)

        # 逆向SDE: dx = (f - g² * score)dt + g * dw
        drift = f - g**2 * score
        diffusion = g

        # Euler-Maruyama
        x_prev = xt + drift * dt + diffusion * np.sqrt(-dt) * noise

        return x_prev

    def probability_flow_ode(self, score_model, xt, t, dt):
        """
        概率流ODE采样一步

        参数:
            score_model: 得分模型
            xt: 当前状态
            t: 当前时间
            dt: 时间步长(负值表示反向)
        """
        # 计算得分
        score = score_model(xt, t)

        # 漂移和扩散
        f = self.drift(xt, t)
        g = self.diffusion(t)

        # 概率流ODE: dx = (f - 0.5 * g² * score)dt
        dx = (f - 0.5 * g**2 * score) * dt

        x_prev = xt + dx

        return x_prev

    def sample_sde(self, score_model, shape, device='cuda', num_steps=1000, eps=1e-3):
        """
        使用SDE采样
        """
        score_model.eval()  # eval()评估模式

        # 从先验采样
        x = self.prior_sampling(shape).to(device)  # 移至GPU/CPU

        # 时间网格
        timesteps = torch.linspace(1.0, eps, num_steps)
        dt = -(1.0 - eps) / num_steps

        # 逆向采样
        with torch.no_grad():  # 禁用梯度计算,节省内存
            for i in range(num_steps):
                t = timesteps[i]
                t_batch = torch.ones(shape[0], device=device) * t
                x = self.reverse_sde(score_model, x, t_batch, dt)

        return x

    def sample_ode(self, score_model, shape, device='cuda', rtol=1e-5, atol=1e-5, method='RK45'):
        """
        使用ODE采样(自适应步长)
        """
        score_model.eval()

        # 初始条件
        x = self.prior_sampling(shape).to(device)

        # 定义ODE
        def ode_func(t, x):
            x = torch.tensor(x, device=device, dtype=torch.float32).reshape(shape)  # 重塑张量形状
            t = torch.ones(shape[0], device=device) * t

            with torch.no_grad():
                score = score_model(x, t)
                f = self.drift(x, t.item())  # 将单元素张量转为Python数值
                g = self.diffusion(t.item())
                dx = (f - 0.5 * g**2 * score).cpu().numpy().flatten()

            return dx

        # 求解ODE
        solution = integrate.solve_ivp(
            ode_func,
            [1.0, 1e-3],
            x.cpu().numpy().flatten(),
            rtol=rtol,
            atol=atol,
            method=method,
        )

        x_final = torch.tensor(solution.y[:, -1], device=device, dtype=torch.float32).reshape(shape)

        return x_final

# 得分模型示例
class ScoreNet(nn.Module):  # 继承nn.Module定义网络层
    """简化的得分网络"""
    def __init__(self, channels=3, time_emb_dim=256):
        super().__init__()  # super()调用父类方法
        self.time_embed = nn.Sequential(
            nn.Linear(1, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )

        # 简化的UNet结构
        self.encoder = nn.Sequential(
            nn.Conv2d(channels, 64, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.SiLU(),
        )

        self.middle = nn.Sequential(
            nn.Conv2d(128, 128, 3, padding=1),
            nn.SiLU(),
        )

        self.decoder = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(64, channels, 3, padding=1),
        )

    def forward(self, x, t):
        # 时间嵌入
        if t.dim() == 0:
            t = t.unsqueeze(0).expand(x.shape[0])  # unsqueeze增加一个维度
        t_emb = self.time_embed(t.view(-1, 1).float())

        # 前向传播
        h = self.encoder(x)
        h = self.middle(h)
        h = self.decoder(h)

        return h

# 使用示例
if __name__ == "__main__":
    # 创建扩散模型
    diffusion = ContinuousDiffusion(sde_type='vp', beta_min=0.1, beta_max=20.0)

    # 创建得分模型
    score_model = ScoreNet(channels=3)

    # 测试前向过程
    x0 = torch.randn(4, 3, 32, 32)
    t = torch.tensor([0.5, 0.5, 0.5, 0.5])
    xt, noise = diffusion.forward_sde(x0, t)

    print(f"x0 shape: {x0.shape}")
    print(f"xt shape: {xt.shape}")
    print(f"noise shape: {noise.shape}")

    # 测试采样(需要训练好的模型)
    # samples = diffusion.sample_sde(score_model, shape=(4, 3, 32, 32))
    # print(f"Samples shape: {samples.shape}")

7. 本章总结

核心概念

  1. 连续时间扩散
  2. 从离散 SDE 到连续 SDE
  3. 数学更优雅,分析更方便
  4. 统一框架

  5. 三种 SDE

  6. VP-SDE :方差保持( DDPM )
  7. VE-SDE :方差爆炸( SMLD )
  8. 子 VP-SDE :改进版本

  9. Score SDE

  10. 用神经网络估计得分函数
  11. 逆向 SDE 进行采样
  12. 统一训练目标

  13. 概率流 ODE

  14. 确定性采样
  15. 精确似然计算
  16. 更快的采样

关键公式

概念 公式
VP-SDE \(dx = -\frac{1}{2}\beta(t)x dt + \sqrt{\beta(t)} dw\)
VE-SDE \(dx = \sqrt{d[\sigma^2(t)]/dt} dw\)
逆向 SDE \(dx = (f - g^2 \nabla \log p)dt + g d\bar{w}\)
概率流 ODE \(\frac{dx}{dt} = f - \frac{1}{2}g^2 \nabla \log p\)

实现要点

Python
# 核心采样循环
def sample_sde(score_model, shape):
    x = prior_sampling(shape)
    for t in reversed(timesteps):
        score = score_model(x, t)
        x = reverse_sde_step(x, t, score)
    return x

📝 自测问题

基础问题

  1. 连续时间优势
  2. 为什么需要连续时间扩散模型?
  3. 离散和连续的主要区别?
  4. 连续时间的数学工具?

  5. 三种 SDE

  6. VP-SDE 、 VE-SDE 、子 VP-SDE 的区别?
  7. 各自的优势和适用场景?
  8. 如何选择合适的 SDE ?

  9. Score SDE

  10. 什么是得分函数?
  11. 如何训练得分模型?
  12. 与 DDPM 的关系?

  13. 概率流 ODE

  14. ODE 与 SDE 的区别?
  15. 如何计算精确似然?
  16. 数值求解方法?

编程练习

  1. 实现三种 SDE 的前向过程
  2. 实现逆向 SDE 采样
  3. 实现概率流 ODE 采样
  4. 比较 SDE 和 ODE 的采样质量

思考题

  1. 连续时间模型的计算挑战?
  2. 如何设计更好的 SDE ?
  3. 概率流 ODE 的局限性?

🔗 下一步

理解了连续时间扩散模型后,我们将进入扩散模型变体与进阶,学习 DDIM 加速采样等高级技术。

→ 下一步:04-扩散模型变体与进阶/01-DDIM 加速采样.md