跳转至

07 - DiT 与 Transformer 扩散架构

⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。

学习时间: 4 小时 重要性: ⭐⭐⭐⭐⭐ Sora/SD3/FLUX 等最新模型的共同基础架构


🎯 学习目标

完成本章后,你将能够: - 理解扩散模型从 U-Net 到 Transformer 的范式转变动因 - 掌握 DiT 的核心设计: Patchify 、 AdaLN-Zero 条件注入 - 了解 U-ViT 、 SiT 、 MDT 、 MM-DiT 等 Transformer 扩散变体 - 理解 Scaling Laws 在扩散 Transformer 中的体现 - 实现一个简化版 DiT 模型


1. 从 U-Net 到 Transformer 的范式转变

1.1 为什么要替换 U-Net

U-Net 自 DDPM 以来一直是扩散模型的标准骨干网络,但它存在固有局限:

维度 U-Net Transformer
归纳偏置 强空间局部性偏置 弱偏置,更灵活
可扩展性 难以单纯堆叠层数提升 Scaling Laws 明确
多模态融合 Cross-attention 拼接 天然支持序列混合
分辨率泛化 需要特殊处理 位置编码灵活适配
工程生态 定制化结构 复用 LLM 训练基础设施

关键动因:随着模型规模增大, Transformer 展现出更明确的 Scaling Law 特性——模型越大、数据越多、效果越好,且改进趋势可预测。这与 LLM 的成功经验一致。

1.2 Transformer 在视觉任务中的发展

Text Only
ViT (2020) → 图像分类
DALL-E (2021) → 自回归图像生成
U-ViT (2023) → 将Transformer嵌入U-Net结构
DiT (2023) → 纯Transformer扩散模型
SD3/FLUX/Sora (2024) → 大规模Transformer扩散模型

2. DiT 原理详解

2.1 论文概述

DiT ( Peebles & Xie, 2023, "Scalable Diffusion Models with Transformers")首次证明了纯 Transformer 架构在扩散模型中可以取代 U-Net ,并展现出清晰的 Scaling Laws 。

2.2 Patchify :图像到序列的转换

DiT 遵循 ViT 的方式,将图像(或潜空间特征)分割为 patch 序列:

\[\text{Input: } z \in \mathbb{R}^{C \times H \times W} \xrightarrow{\text{Patchify}} X \in \mathbb{R}^{N \times D}\]

其中 \(N = \frac{H \times W}{p^2}\) 为 patch 数量,\(p\) 为 patch 大小,\(D\) 为 embedding 维度。

Python
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):  # 继承nn.Module定义网络层
    """将潜空间特征图转换为patch序列"""

    def __init__(self, img_size=32, patch_size=2, in_channels=4, embed_dim=768):
        super().__init__()  # super()调用父类方法
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # x: [B, C, H, W] → [B, D, H/p, W/p] → [B, N, D]
        x = self.proj(x)                    # [B, D, H/p, W/p]
        x = x.flatten(2).transpose(1, 2)    # [B, N, D]
        return x

Patch 大小的影响: - \(p=2\): 序列更长,计算量大,细节更好 - \(p=4\): 序列更短,效率高,适合高分辨率 - \(p=8\): 极端压缩,信息损失较大

2.3 AdaLN-Zero :条件注入机制

DiT 探索了四种将时间步 \(t\) 和类别标签 \(c\) 注入 Transformer Block 的方式:

方式 方法 FID
In-context \(t, c\) 作为额外 token 拼接到序列中 较高
Cross-attention \(t, c\) 作为 cross-attention 的 KV 中等
AdaLN \(t, c\) 预测 LayerNorm 的 \(\gamma, \beta\) 较低
AdaLN-Zero AdaLN + 零初始化 scale 参数 最低

AdaLN-Zero 的核心思想

标准 LayerNorm 对输入进行归一化后,施加可学习的 scale 和 shift : $\(\text{LN}(x) = \gamma \cdot \frac{x - \mu}{\sigma} + \beta\)$

AdaLN-Zero 将 \(\gamma, \beta\) 替换为条件预测的参数,并额外引入一个零初始化的门控参数 \(\alpha\)

\[\gamma, \beta, \alpha = \text{MLP}(t_{emb} + c_{emb})\]
\[\text{Block Output} = x + \alpha \cdot \text{Attention}(\text{AdaLN}(x; \gamma, \beta))\]

初始化时 \(\alpha = 0\),使得每个 DiT Block 在训练初期是恒等映射,保证训练稳定性。

Python
class AdaLNZero(nn.Module):
    """Adaptive Layer Norm Zero — DiT的核心条件注入模块"""

    def __init__(self, dim, cond_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        # 预测 gamma, beta, alpha 三组参数
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cond_dim, 6 * dim)  # 为attention和MLP各预测3个参数
        )
        # 零初始化
        nn.init.zeros_(self.adaLN_modulation[-1].weight)  # [-1]负索引取最后元素
        nn.init.zeros_(self.adaLN_modulation[-1].bias)

    def forward(self, x, cond):
        # cond: [B, cond_dim] → 预测6组调制参数
        (gamma1, beta1, alpha1,
         gamma2, beta2, alpha2) = self.adaLN_modulation(cond).chunk(6, dim=-1)

        return gamma1, beta1, alpha1, gamma2, beta2, alpha2

    def modulate(self, x, gamma, beta):
        return self.norm(x) * (1 + gamma.unsqueeze(1)) + beta.unsqueeze(1)  # unsqueeze增加一个维度

2.4 DiT 整体架构

Text Only
输入潜变量 z ∈ R^(C×H×W)
   [PatchEmbed] → [B, N, D]
   + Positional Embedding (sin-cos)
   ┌─────────────────────────┐
   │     DiT Block × L       │
   │ ┌─────────────────────┐ │
   │ │ AdaLN-Zero(t, c)    │ │
   │ │     ↓                │ │
   │ │ Self-Attention       │ │
   │ │     ↓                │ │
   │ │ AdaLN-Zero(t, c)    │ │
   │ │     ↓                │ │
   │ │ Pointwise FFN        │ │
   │ └─────────────────────┘ │
   └─────────────────────────┘
   [Final AdaLN + Linear]
   [Unpatchify] → R^(2C×H×W)  (预测噪声+对角协方差)

3. Transformer 扩散模型家族

3.1 U-ViT

U-ViT ( Bao et al., 2023, "All are Worth Words: A ViT Backbone for Diffusion Models")将 U-Net 的跳跃连接引入 ViT :

  • 将时间步、条件、图像 patches 视为统一的 token 序列
  • 浅层和深层之间引入 long skip connections
  • 保持了 U-Net 跳跃连接的优势,同时获得 Transformer 的灵活性
\[X_{deep} = \text{Block}_{deep}(X_{shallow}) + \text{Linear}(X_{shallow})\]

3.2 SiT ( Scalable Interpolant Transformers )

SiT ( Ma et al., 2024 )结合 DiT 架构与 Flow Matching 训练:

  • 使用随机插值( Stochastic Interpolant )框架统一多种扩散/流匹配目标
  • 在 Interpolant 框架下自由选择不同的时间调度和速度场参数化
  • 同等 FLOPs 下优于 DiT

3.3 MDT ( Masked Diffusion Transformer )

MDT ( Gao et al., 2023, "Masked Diffusion Transformer is a Strong Image Synthesizer")引入了掩码建模思想:

  • 训练时随机 mask 部分 patch tokens
  • 模型同时学习去噪和 patch 预测
  • 加速训练收敛,提升上下文学习能力

3.4 MM-DiT

MM-DiT ( Esser et al., 2024 )是 SD3/SD3.5 使用的多模态 Transformer (详见上一章):

  • 图像和文本 token 联合注意力
  • 模态专属的 QKV 投影和 MLP
  • 支持多种文本编码器的灵活接入

3.5 架构对比

模型 条件注入 文本处理 跳跃连接 训练目标
DiT AdaLN-Zero 类别标签 DDPM \(\epsilon\)-prediction
U-ViT In-context Token 拼接 DDPM \(\epsilon\)-prediction
SiT AdaLN-Zero 类别标签 Flow Matching
MDT AdaLN-Zero 类别标签 DDPM + Mask
MM-DiT AdaLN-Zero 联合注意力 Rectified Flow

4. Scaling Laws 在扩散 Transformer 中的应用

4.1 DiT 的 Scaling 实验

DiT 论文中系统测试了不同规模的模型:

模型 层数 维度 注意力头数 参数量 FID (ImageNet 256)
DiT-S/2 12 384 6 33M 68.4
DiT-B/2 12 768 12 130M 43.5
DiT-L/2 24 1024 16 458M 23.3
DiT-XL/2 28 1152 16 675M 9.62

关键发现: 1. FID 与模型 FLOPs 呈幂律关系\(\text{FID} \propto \text{GFLOPs}^{-\alpha}\) 2. 增大模型比增大 patch 更有效:相同 GFLOPs 下,大模型+大 patch < 大模型+小 patch 3. Scaling 趋势未饱和:暗示更大模型将持续改进

4.2 工业级 Scaling 实践

Sora 、 FLUX 等产品验证了扩散 Transformer 的 Scaling Law :

\[L(N, D) \approx \frac{A}{N^{\alpha_N}} + \frac{B}{D^{\alpha_D}} + L_0\]

其中 \(N\) 为参数量,\(D\) 为数据量,\(L\) 为验证损失。

实际数据: - FLUX.1: 12B 参数 → 目前开源最佳图像质量 - Sora: 估计数十 B 参数 → 突破性视频质量 - 趋势:扩散模型正在重复 LLM 的 scaling 故事


5. 简化版 DiT 代码实现

Python
import torch
import torch.nn as nn
import math

class SinusoidalPosEmb(nn.Module):
    """正弦时间步编码"""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        device = t.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t[:, None].float() * emb[None, :]
        return torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)  # torch.cat沿已有维度拼接张量

class DiTBlock(nn.Module):
    """DiT Transformer Block with AdaLN-Zero"""

    def __init__(self, dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # Self-Attention
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)

        # FFN
        mlp_hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden),
            nn.GELU(),
            nn.Linear(mlp_hidden, dim),
        )

        # AdaLN-Zero: 两层各需 gamma, beta, alpha → 6 * dim
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.adaLN = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim),
        )
        # 关键:零初始化
        nn.init.zeros_(self.adaLN[-1].weight)
        nn.init.zeros_(self.adaLN[-1].bias)

    def forward(self, x, cond):
        """
        x: [B, N, D] — patch token序列
        cond: [B, D] — 条件嵌入(时间步+类别)
        """
        B, N, D = x.shape

        # 预测6组调制参数
        gamma1, beta1, alpha1, gamma2, beta2, alpha2 = \
            self.adaLN(cond).chunk(6, dim=-1)

        # --- Attention分支 ---
        h = self.norm1(x) * (1 + gamma1.unsqueeze(1)) + beta1.unsqueeze(1)
        qkv = self.qkv(h).reshape(B, N, 3, self.num_heads, D // self.num_heads)  # 重塑张量形状
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        h = (attn @ v).transpose(1, 2).reshape(B, N, D)
        h = self.proj(h)
        x = x + alpha1.unsqueeze(1) * h  # 零初始化门控

        # --- FFN分支 ---
        h = self.norm2(x) * (1 + gamma2.unsqueeze(1)) + beta2.unsqueeze(1)
        h = self.mlp(h)
        x = x + alpha2.unsqueeze(1) * h  # 零初始化门控

        return x

class DiT(nn.Module):
    """简化版 Diffusion Transformer"""

    def __init__(
        self,
        img_size=32,
        patch_size=2,
        in_channels=4,
        dim=768,
        depth=12,
        num_heads=12,
        num_classes=10,
        mlp_ratio=4.0,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        # Patch Embedding
        self.patch_embed = nn.Conv2d(
            in_channels, dim, kernel_size=patch_size, stride=patch_size
        )
        # 位置编码 (固定sin-cos)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches, dim)
        )

        # 条件编码
        self.time_embed = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )
        self.class_embed = nn.Embedding(num_classes, dim)

        # Transformer Blocks
        self.blocks = nn.ModuleList([
            DiTBlock(dim, num_heads, mlp_ratio) for _ in range(depth)
        ])

        # 输出层
        self.final_norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.final_adaLN = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 2 * dim),
        )
        self.final_proj = nn.Linear(dim, patch_size ** 2 * in_channels * 2)
        # 2倍channels用于预测噪声和对角方差

        self._init_weights()

    def _init_weights(self):
        # 初始化位置编码
        nn.init.normal_(self.pos_embed, std=0.02)
        # 零初始化输出层
        nn.init.zeros_(self.final_proj.weight)
        nn.init.zeros_(self.final_proj.bias)
        nn.init.zeros_(self.final_adaLN[-1].weight)
        nn.init.zeros_(self.final_adaLN[-1].bias)

    def unpatchify(self, x, H, W):
        """[B, N, p*p*C] → [B, C, H, W]"""
        p = self.patch_size
        c = x.shape[-1] // (p * p)
        h, w = H // p, W // p
        x = x.reshape(-1, h, w, p, p, c)
        x = x.permute(0, 5, 1, 3, 2, 4).reshape(-1, c, H, W)
        return x

    def forward(self, x, t, y):
        """
        x: [B, C, H, W] — 带噪声的潜变量
        t: [B] — 时间步
        y: [B] — 类别标签
        """
        B, C, H, W = x.shape

        # Patchify + 位置编码
        x = self.patch_embed(x).flatten(2).transpose(1, 2)  # [B, N, D]
        x = x + self.pos_embed

        # 条件编码
        cond = self.time_embed(t) + self.class_embed(y)     # [B, D]

        # Transformer Blocks
        for block in self.blocks:
            x = block(x, cond)

        # 输出投影
        gamma, beta = self.final_adaLN(cond).chunk(2, dim=-1)
        x = self.final_norm(x) * (1 + gamma.unsqueeze(1)) + beta.unsqueeze(1)
        x = self.final_proj(x)   # [B, N, p*p*2C]
        x = self.unpatchify(x, H, W)  # [B, 2C, H, W]

        noise_pred, var_pred = x.chunk(2, dim=1)
        return noise_pred, var_pred

# === 使用示例 ===
if __name__ == "__main__":
    model = DiT(
        img_size=32, patch_size=2, in_channels=4,
        dim=384, depth=12, num_heads=6, num_classes=10
    )
    print(f"DiT-S/2 参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")

    x = torch.randn(2, 4, 32, 32)  # 潜空间输入
    t = torch.randint(0, 1000, (2,))
    y = torch.randint(0, 10, (2,))

    noise_pred, var_pred = model(x, t, y)
    print(f"输出形状: noise={noise_pred.shape}, var={var_pred.shape}")
    # 输出: noise=torch.Size([2, 4, 32, 32]), var=torch.Size([2, 4, 32, 32])

📋 面试要点

高频面试题

  1. DiT 相比 U-Net 的核心优势是什么?
  2. Transformer 架构具有清晰的 Scaling Laws ,模型越大效果持续提升
  3. 弱归纳偏置使其更灵活,易于多模态扩展
  4. 可复用 LLM 训练基础设施(并行策略、优化器等)
  5. U-Net 的下采样/上采样结构在极高分辨率下效率受限

  6. AdaLN-Zero 为什么效果最好?

  7. 零初始化使每个 Block 初始为恒等映射,类似 ResNet 的残差学习
  8. 训练初期模型=直接跳过所有层,梯度直接传到输入,训练更稳定
  9. 条件信息通过可学习的 scale/shift 注入,比简单拼接更灵活

  10. DiT 的 Patch Size 如何选择?

  11. \(p=2\) 最常用( DiT-XL/2 达到 FID 9.62 )
  12. 较小的 patch 保留更多空间细节但增加序列长度(计算量 \(O(N^2)\)
  13. 实际中通常在潜空间(如 8×下采样后的 32×32 )上使用 patch_size=2
  14. 对于高分辨率可考虑 patch_size=4 配合窗口注意力

  15. 如何理解扩散 Transformer 的 Scaling Laws ?

  16. FID 与 GFLOPs 呈幂律下降关系
  17. 这与 LLM 中 Loss 与参数量/数据量的幂律关系类似
  18. Sora 、 FLUX 等产品验证了大规模扩散 Transformer 的有效性
  19. 意味着扩散模型的天花板远未达到

✏️ 练习

练习 1 : DiT-S/2 训练

在 CIFAR-10 上训练简化版 DiT-S/2 ,对比不同条件注入方式( in-context / cross-attention / AdaLN-Zero )的收敛速度和最终 FID 。

练习 2 : Patch Size 消融实验

固定模型参数量,测试 patch_size=½/4 对生成质量和训练速度的影响。

练习 3 :结构变体实验

实现 U-ViT 的 long skip connection 并对比纯 DiT ,观察收敛速度的差异。

练习 4 :论文精读

  • 精读 DiT 原始论文,关注 Table 1-3 的 Scaling 实验
  • 阅读 U-ViT 论文,理解其与 DiT 在设计哲学上的差异

参考文献

  1. Peebles & Xie, 2023. "Scalable Diffusion Models with Transformers" — DiT 原始论文
  2. Bao et al., 2023. "All are Worth Words: A ViT Backbone for Diffusion Models" — U-ViT
  3. Ma et al., 2024. "Sit: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers" — SiT
  4. Gao et al., 2023. "Masked Diffusion Transformer is a Strong Image Synthesizer" — MDT
  5. Esser et al., 2024. "Scaling Rectified Flow Transformers for High-Resolution Image Synthesis" — MM-DiT/SD3
  6. Dosovitskiy et al., 2021. "An Image is Worth 16x16 Words" — ViT

下一章: 08-流匹配与一致性模型 — 探索更高效的扩散训练与采样范式