07 - DiT 与 Transformer 扩散架构¶
⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。
学习时间: 4 小时 重要性: ⭐⭐⭐⭐⭐ Sora/SD3/FLUX 等最新模型的共同基础架构
🎯 学习目标¶
完成本章后,你将能够: - 理解扩散模型从 U-Net 到 Transformer 的范式转变动因 - 掌握 DiT 的核心设计: Patchify 、 AdaLN-Zero 条件注入 - 了解 U-ViT 、 SiT 、 MDT 、 MM-DiT 等 Transformer 扩散变体 - 理解 Scaling Laws 在扩散 Transformer 中的体现 - 实现一个简化版 DiT 模型
1. 从 U-Net 到 Transformer 的范式转变¶
1.1 为什么要替换 U-Net¶
U-Net 自 DDPM 以来一直是扩散模型的标准骨干网络,但它存在固有局限:
| 维度 | U-Net | Transformer |
|---|---|---|
| 归纳偏置 | 强空间局部性偏置 | 弱偏置,更灵活 |
| 可扩展性 | 难以单纯堆叠层数提升 | Scaling Laws 明确 |
| 多模态融合 | Cross-attention 拼接 | 天然支持序列混合 |
| 分辨率泛化 | 需要特殊处理 | 位置编码灵活适配 |
| 工程生态 | 定制化结构 | 复用 LLM 训练基础设施 |
关键动因:随着模型规模增大, Transformer 展现出更明确的 Scaling Law 特性——模型越大、数据越多、效果越好,且改进趋势可预测。这与 LLM 的成功经验一致。
1.2 Transformer 在视觉任务中的发展¶
ViT (2020) → 图像分类
↓
DALL-E (2021) → 自回归图像生成
↓
U-ViT (2023) → 将Transformer嵌入U-Net结构
↓
DiT (2023) → 纯Transformer扩散模型
↓
SD3/FLUX/Sora (2024) → 大规模Transformer扩散模型
2. DiT 原理详解¶
2.1 论文概述¶
DiT ( Peebles & Xie, 2023, "Scalable Diffusion Models with Transformers")首次证明了纯 Transformer 架构在扩散模型中可以取代 U-Net ,并展现出清晰的 Scaling Laws 。
2.2 Patchify :图像到序列的转换¶
DiT 遵循 ViT 的方式,将图像(或潜空间特征)分割为 patch 序列:
其中 \(N = \frac{H \times W}{p^2}\) 为 patch 数量,\(p\) 为 patch 大小,\(D\) 为 embedding 维度。
import torch
import torch.nn as nn
class PatchEmbed(nn.Module): # 继承nn.Module定义网络层
"""将潜空间特征图转换为patch序列"""
def __init__(self, img_size=32, patch_size=2, in_channels=4, embed_dim=768):
super().__init__() # super()调用父类方法
self.num_patches = (img_size // patch_size) ** 2
self.proj = nn.Conv2d(
in_channels, embed_dim,
kernel_size=patch_size, stride=patch_size
)
def forward(self, x):
# x: [B, C, H, W] → [B, D, H/p, W/p] → [B, N, D]
x = self.proj(x) # [B, D, H/p, W/p]
x = x.flatten(2).transpose(1, 2) # [B, N, D]
return x
Patch 大小的影响: - \(p=2\): 序列更长,计算量大,细节更好 - \(p=4\): 序列更短,效率高,适合高分辨率 - \(p=8\): 极端压缩,信息损失较大
2.3 AdaLN-Zero :条件注入机制¶
DiT 探索了四种将时间步 \(t\) 和类别标签 \(c\) 注入 Transformer Block 的方式:
| 方式 | 方法 | FID |
|---|---|---|
| In-context | 将 \(t, c\) 作为额外 token 拼接到序列中 | 较高 |
| Cross-attention | 将 \(t, c\) 作为 cross-attention 的 KV | 中等 |
| AdaLN | 用 \(t, c\) 预测 LayerNorm 的 \(\gamma, \beta\) | 较低 |
| AdaLN-Zero | AdaLN + 零初始化 scale 参数 | 最低 |
AdaLN-Zero 的核心思想:
标准 LayerNorm 对输入进行归一化后,施加可学习的 scale 和 shift : $\(\text{LN}(x) = \gamma \cdot \frac{x - \mu}{\sigma} + \beta\)$
AdaLN-Zero 将 \(\gamma, \beta\) 替换为条件预测的参数,并额外引入一个零初始化的门控参数 \(\alpha\):
初始化时 \(\alpha = 0\),使得每个 DiT Block 在训练初期是恒等映射,保证训练稳定性。
class AdaLNZero(nn.Module):
"""Adaptive Layer Norm Zero — DiT的核心条件注入模块"""
def __init__(self, dim, cond_dim):
super().__init__()
self.norm = nn.LayerNorm(dim, elementwise_affine=False)
# 预测 gamma, beta, alpha 三组参数
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(cond_dim, 6 * dim) # 为attention和MLP各预测3个参数
)
# 零初始化
nn.init.zeros_(self.adaLN_modulation[-1].weight) # [-1]负索引取最后元素
nn.init.zeros_(self.adaLN_modulation[-1].bias)
def forward(self, x, cond):
# cond: [B, cond_dim] → 预测6组调制参数
(gamma1, beta1, alpha1,
gamma2, beta2, alpha2) = self.adaLN_modulation(cond).chunk(6, dim=-1)
return gamma1, beta1, alpha1, gamma2, beta2, alpha2
def modulate(self, x, gamma, beta):
return self.norm(x) * (1 + gamma.unsqueeze(1)) + beta.unsqueeze(1) # unsqueeze增加一个维度
2.4 DiT 整体架构¶
输入潜变量 z ∈ R^(C×H×W)
↓
[PatchEmbed] → [B, N, D]
↓
+ Positional Embedding (sin-cos)
↓
┌─────────────────────────┐
│ DiT Block × L │
│ ┌─────────────────────┐ │
│ │ AdaLN-Zero(t, c) │ │
│ │ ↓ │ │
│ │ Self-Attention │ │
│ │ ↓ │ │
│ │ AdaLN-Zero(t, c) │ │
│ │ ↓ │ │
│ │ Pointwise FFN │ │
│ └─────────────────────┘ │
└─────────────────────────┘
↓
[Final AdaLN + Linear]
↓
[Unpatchify] → R^(2C×H×W) (预测噪声+对角协方差)
3. Transformer 扩散模型家族¶
3.1 U-ViT¶
U-ViT ( Bao et al., 2023, "All are Worth Words: A ViT Backbone for Diffusion Models")将 U-Net 的跳跃连接引入 ViT :
- 将时间步、条件、图像 patches 视为统一的 token 序列
- 浅层和深层之间引入 long skip connections
- 保持了 U-Net 跳跃连接的优势,同时获得 Transformer 的灵活性
3.2 SiT ( Scalable Interpolant Transformers )¶
SiT ( Ma et al., 2024 )结合 DiT 架构与 Flow Matching 训练:
- 使用随机插值( Stochastic Interpolant )框架统一多种扩散/流匹配目标
- 在 Interpolant 框架下自由选择不同的时间调度和速度场参数化
- 同等 FLOPs 下优于 DiT
3.3 MDT ( Masked Diffusion Transformer )¶
MDT ( Gao et al., 2023, "Masked Diffusion Transformer is a Strong Image Synthesizer")引入了掩码建模思想:
- 训练时随机 mask 部分 patch tokens
- 模型同时学习去噪和 patch 预测
- 加速训练收敛,提升上下文学习能力
3.4 MM-DiT¶
MM-DiT ( Esser et al., 2024 )是 SD3/SD3.5 使用的多模态 Transformer (详见上一章):
- 图像和文本 token 联合注意力
- 模态专属的 QKV 投影和 MLP
- 支持多种文本编码器的灵活接入
3.5 架构对比¶
| 模型 | 条件注入 | 文本处理 | 跳跃连接 | 训练目标 |
|---|---|---|---|---|
| DiT | AdaLN-Zero | 类别标签 | 无 | DDPM \(\epsilon\)-prediction |
| U-ViT | In-context | Token 拼接 | 有 | DDPM \(\epsilon\)-prediction |
| SiT | AdaLN-Zero | 类别标签 | 无 | Flow Matching |
| MDT | AdaLN-Zero | 类别标签 | 无 | DDPM + Mask |
| MM-DiT | AdaLN-Zero | 联合注意力 | 无 | Rectified Flow |
4. Scaling Laws 在扩散 Transformer 中的应用¶
4.1 DiT 的 Scaling 实验¶
DiT 论文中系统测试了不同规模的模型:
| 模型 | 层数 | 维度 | 注意力头数 | 参数量 | FID (ImageNet 256) |
|---|---|---|---|---|---|
| DiT-S/2 | 12 | 384 | 6 | 33M | 68.4 |
| DiT-B/2 | 12 | 768 | 12 | 130M | 43.5 |
| DiT-L/2 | 24 | 1024 | 16 | 458M | 23.3 |
| DiT-XL/2 | 28 | 1152 | 16 | 675M | 9.62 |
关键发现: 1. FID 与模型 FLOPs 呈幂律关系:\(\text{FID} \propto \text{GFLOPs}^{-\alpha}\) 2. 增大模型比增大 patch 更有效:相同 GFLOPs 下,大模型+大 patch < 大模型+小 patch 3. Scaling 趋势未饱和:暗示更大模型将持续改进
4.2 工业级 Scaling 实践¶
Sora 、 FLUX 等产品验证了扩散 Transformer 的 Scaling Law :
其中 \(N\) 为参数量,\(D\) 为数据量,\(L\) 为验证损失。
实际数据: - FLUX.1: 12B 参数 → 目前开源最佳图像质量 - Sora: 估计数十 B 参数 → 突破性视频质量 - 趋势:扩散模型正在重复 LLM 的 scaling 故事
5. 简化版 DiT 代码实现¶
import torch
import torch.nn as nn
import math
class SinusoidalPosEmb(nn.Module):
"""正弦时间步编码"""
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, t):
device = t.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = t[:, None].float() * emb[None, :]
return torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) # torch.cat沿已有维度拼接张量
class DiTBlock(nn.Module):
"""DiT Transformer Block with AdaLN-Zero"""
def __init__(self, dim, num_heads, mlp_ratio=4.0):
super().__init__()
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
# Self-Attention
self.qkv = nn.Linear(dim, 3 * dim)
self.proj = nn.Linear(dim, dim)
# FFN
mlp_hidden = int(dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_hidden),
nn.GELU(),
nn.Linear(mlp_hidden, dim),
)
# AdaLN-Zero: 两层各需 gamma, beta, alpha → 6 * dim
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
self.adaLN = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, 6 * dim),
)
# 关键:零初始化
nn.init.zeros_(self.adaLN[-1].weight)
nn.init.zeros_(self.adaLN[-1].bias)
def forward(self, x, cond):
"""
x: [B, N, D] — patch token序列
cond: [B, D] — 条件嵌入(时间步+类别)
"""
B, N, D = x.shape
# 预测6组调制参数
gamma1, beta1, alpha1, gamma2, beta2, alpha2 = \
self.adaLN(cond).chunk(6, dim=-1)
# --- Attention分支 ---
h = self.norm1(x) * (1 + gamma1.unsqueeze(1)) + beta1.unsqueeze(1)
qkv = self.qkv(h).reshape(B, N, 3, self.num_heads, D // self.num_heads) # 重塑张量形状
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
h = (attn @ v).transpose(1, 2).reshape(B, N, D)
h = self.proj(h)
x = x + alpha1.unsqueeze(1) * h # 零初始化门控
# --- FFN分支 ---
h = self.norm2(x) * (1 + gamma2.unsqueeze(1)) + beta2.unsqueeze(1)
h = self.mlp(h)
x = x + alpha2.unsqueeze(1) * h # 零初始化门控
return x
class DiT(nn.Module):
"""简化版 Diffusion Transformer"""
def __init__(
self,
img_size=32,
patch_size=2,
in_channels=4,
dim=768,
depth=12,
num_heads=12,
num_classes=10,
mlp_ratio=4.0,
):
super().__init__()
self.patch_size = patch_size
self.num_patches = (img_size // patch_size) ** 2
# Patch Embedding
self.patch_embed = nn.Conv2d(
in_channels, dim, kernel_size=patch_size, stride=patch_size
)
# 位置编码 (固定sin-cos)
self.pos_embed = nn.Parameter(
torch.zeros(1, self.num_patches, dim)
)
# 条件编码
self.time_embed = nn.Sequential(
SinusoidalPosEmb(dim),
nn.Linear(dim, dim),
nn.GELU(),
nn.Linear(dim, dim),
)
self.class_embed = nn.Embedding(num_classes, dim)
# Transformer Blocks
self.blocks = nn.ModuleList([
DiTBlock(dim, num_heads, mlp_ratio) for _ in range(depth)
])
# 输出层
self.final_norm = nn.LayerNorm(dim, elementwise_affine=False)
self.final_adaLN = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, 2 * dim),
)
self.final_proj = nn.Linear(dim, patch_size ** 2 * in_channels * 2)
# 2倍channels用于预测噪声和对角方差
self._init_weights()
def _init_weights(self):
# 初始化位置编码
nn.init.normal_(self.pos_embed, std=0.02)
# 零初始化输出层
nn.init.zeros_(self.final_proj.weight)
nn.init.zeros_(self.final_proj.bias)
nn.init.zeros_(self.final_adaLN[-1].weight)
nn.init.zeros_(self.final_adaLN[-1].bias)
def unpatchify(self, x, H, W):
"""[B, N, p*p*C] → [B, C, H, W]"""
p = self.patch_size
c = x.shape[-1] // (p * p)
h, w = H // p, W // p
x = x.reshape(-1, h, w, p, p, c)
x = x.permute(0, 5, 1, 3, 2, 4).reshape(-1, c, H, W)
return x
def forward(self, x, t, y):
"""
x: [B, C, H, W] — 带噪声的潜变量
t: [B] — 时间步
y: [B] — 类别标签
"""
B, C, H, W = x.shape
# Patchify + 位置编码
x = self.patch_embed(x).flatten(2).transpose(1, 2) # [B, N, D]
x = x + self.pos_embed
# 条件编码
cond = self.time_embed(t) + self.class_embed(y) # [B, D]
# Transformer Blocks
for block in self.blocks:
x = block(x, cond)
# 输出投影
gamma, beta = self.final_adaLN(cond).chunk(2, dim=-1)
x = self.final_norm(x) * (1 + gamma.unsqueeze(1)) + beta.unsqueeze(1)
x = self.final_proj(x) # [B, N, p*p*2C]
x = self.unpatchify(x, H, W) # [B, 2C, H, W]
noise_pred, var_pred = x.chunk(2, dim=1)
return noise_pred, var_pred
# === 使用示例 ===
if __name__ == "__main__":
model = DiT(
img_size=32, patch_size=2, in_channels=4,
dim=384, depth=12, num_heads=6, num_classes=10
)
print(f"DiT-S/2 参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
x = torch.randn(2, 4, 32, 32) # 潜空间输入
t = torch.randint(0, 1000, (2,))
y = torch.randint(0, 10, (2,))
noise_pred, var_pred = model(x, t, y)
print(f"输出形状: noise={noise_pred.shape}, var={var_pred.shape}")
# 输出: noise=torch.Size([2, 4, 32, 32]), var=torch.Size([2, 4, 32, 32])
📋 面试要点¶
高频面试题¶
- DiT 相比 U-Net 的核心优势是什么?
- Transformer 架构具有清晰的 Scaling Laws ,模型越大效果持续提升
- 弱归纳偏置使其更灵活,易于多模态扩展
- 可复用 LLM 训练基础设施(并行策略、优化器等)
-
U-Net 的下采样/上采样结构在极高分辨率下效率受限
-
AdaLN-Zero 为什么效果最好?
- 零初始化使每个 Block 初始为恒等映射,类似 ResNet 的残差学习
- 训练初期模型=直接跳过所有层,梯度直接传到输入,训练更稳定
-
条件信息通过可学习的 scale/shift 注入,比简单拼接更灵活
-
DiT 的 Patch Size 如何选择?
- \(p=2\) 最常用( DiT-XL/2 达到 FID 9.62 )
- 较小的 patch 保留更多空间细节但增加序列长度(计算量 \(O(N^2)\))
- 实际中通常在潜空间(如 8×下采样后的 32×32 )上使用 patch_size=2
-
对于高分辨率可考虑 patch_size=4 配合窗口注意力
-
如何理解扩散 Transformer 的 Scaling Laws ?
- FID 与 GFLOPs 呈幂律下降关系
- 这与 LLM 中 Loss 与参数量/数据量的幂律关系类似
- Sora 、 FLUX 等产品验证了大规模扩散 Transformer 的有效性
- 意味着扩散模型的天花板远未达到
✏️ 练习¶
练习 1 : DiT-S/2 训练¶
在 CIFAR-10 上训练简化版 DiT-S/2 ,对比不同条件注入方式( in-context / cross-attention / AdaLN-Zero )的收敛速度和最终 FID 。
练习 2 : Patch Size 消融实验¶
固定模型参数量,测试 patch_size=½/4 对生成质量和训练速度的影响。
练习 3 :结构变体实验¶
实现 U-ViT 的 long skip connection 并对比纯 DiT ,观察收敛速度的差异。
练习 4 :论文精读¶
- 精读 DiT 原始论文,关注 Table 1-3 的 Scaling 实验
- 阅读 U-ViT 论文,理解其与 DiT 在设计哲学上的差异
参考文献¶
- Peebles & Xie, 2023. "Scalable Diffusion Models with Transformers" — DiT 原始论文
- Bao et al., 2023. "All are Worth Words: A ViT Backbone for Diffusion Models" — U-ViT
- Ma et al., 2024. "Sit: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers" — SiT
- Gao et al., 2023. "Masked Diffusion Transformer is a Strong Image Synthesizer" — MDT
- Esser et al., 2024. "Scaling Rectified Flow Transformers for High-Resolution Image Synthesis" — MM-DiT/SD3
- Dosovitskiy et al., 2021. "An Image is Worth 16x16 Words" — ViT
下一章: 08-流匹配与一致性模型 — 探索更高效的扩散训练与采样范式