第 12 章 视觉 Transformer¶
📚 章节概述¶
本章深入讲解 Transformer 在计算机视觉中的关键应用。从 2020 年 ViT 横空出世,到近年的分类、分割、生成与多模态模型广泛采用 ViT 系路线,视觉建模范式发生了明显变化。本章将从 ViT 原始架构出发,涵盖代表性变体演进、高效设计、下游任务应用,以及工程实践中的关键代码实现。
学习时间: 7-10 天 难度等级:⭐⭐⭐⭐⭐ 前置知识:第 5-6 章 CNN 基础、 NLP 领域 Transformer 架构( Self-Attention / Multi-Head Attention / Positional Encoding )
🎯 学习目标¶
完成本章后,你将能够: - 深入理解 ViT 的每个组件( Patch Embedding / Position Embedding / CLS Token / MHA ) - 掌握 Swin Transformer 的层级设计与窗口注意力机制 - 了解 DeiT 、 BEiT 、 MAE 、 EVA 、 SigLIP 等重要变体的核心创新 - 比较 ViT 与 CNN 的归纳偏置、数据效率、计算复杂度差异 - 理解 ViT 在检测( DETR )、分割( SAM )、生成( DiT )等下游任务的应用 - 能够用 PyTorch 手写完整的 ViT 前向传播 - 熟练使用 HuggingFace 预训练 ViT 模型 - 掌握 8 组关键复盘问题
12.1 ViT 原始架构深度解析¶
12.1.1 论文背景¶
论文:An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale( Dosovitskiy et al., ICLR 2021 )
核心思想:完全抛弃卷积操作,将图像视为一组 patch 序列,直接使用标准 Transformer 编码器进行分类。在 JFT-300M 等大规模数据预训练条件下, ViT 在当时的 ImageNet 评测中取得了很有竞争力的结果。
关键发现: - Transformer 在大规模数据下可以超越 CNN - 在中小数据集上 ViT 不如 CNN (缺乏归纳偏置) - Scaling Law 在视觉领域同样成立
12.1.2 Patch Embedding¶
Patch Embedding 是 ViT 将 2D 图像转换为 1D 序列的核心步骤。
原理: 1. 将 \(H \times W \times C\) 的图像切分为 \(N\) 个不重叠的 patch ,每个 patch 大小为 \(P \times P\) 2. patch 数量 \(N = \frac{H \times W}{P^2}\),例如 \(224 \times 224\) 图像、\(P=16\) → \(N = 196\) 3. 每个 patch 展平为 \(P^2 \cdot C\) 维向量(\(16 \times 16 \times 3 = 768\)) 4. 通过线性投影映射到 \(D\) 维嵌入空间
实现方式:实践中通常用一个 kernel_size=stride=P 的卷积层高效实现:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class PatchEmbedding(nn.Module): # 继承nn.Module定义网络层
"""将图像分割为patches并投影到嵌入空间
用Conv2d实现,等价于:切patch → 展平 → 线性投影
但Conv2d利用了GPU并行计算,效率更高
"""
def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
super().__init__() # super()调用父类方法
self.img_size = img_size
self.patch_size = patch_size
self.n_patches = (img_size // patch_size) ** 2 # 196 for 224/16
# Conv2d with kernel_size=stride=patch_size 等价于切patch+线性投影
self.proj = nn.Conv2d(
in_channels, embed_dim,
kernel_size=patch_size, stride=patch_size
)
def forward(self, x):
# x: (B, C, H, W) e.g. (B, 3, 224, 224)
x = self.proj(x) # (B, embed_dim, H/P, W/P) e.g. (B, 768, 14, 14)
x = x.flatten(2) # (B, embed_dim, N) e.g. (B, 768, 196)
x = x.transpose(1, 2) # (B, N, embed_dim) e.g. (B, 196, 768)
return x
12.1.3 Position Embedding¶
Transformer 本身对输入顺序不敏感(排列等变性),因此必须额外注入位置信息。
ViT 使用可学习的 1D 位置编码:
| 位置编码类型 | 描述 | 代表模型 |
|---|---|---|
| 可学习 1D | 每个位置一个可训练向量,直接加到 patch embedding 上 | ViT |
| 正弦余弦 | 使用固定的 sin/cos 函数,不需要训练 | 原始 Transformer |
| 可学习 2D | 分别为行和列学习位置编码 | DeiT-III |
| 相对位置偏置 | 编码 patch 对之间的相对距离 | Swin Transformer |
| 旋转位置编码(RoPE) | 通过旋转矩阵注入位置 | EVA-02 |
| 无位置编码 | 利用卷积隐式编码位置 | CvT |
关键细节: ViT 论文实验表明,可学习 1D 与可学习 2D 位置编码效果几乎一致,因此默认采用更简单的 1D 方案。
# 可学习的1D位置编码
# +1是因为还有一个CLS token
pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
nn.init.trunc_normal_(pos_embed, std=0.02)
位置编码的可视化:训练后的位置编码会自动学习到 2D 空间结构——相邻 patch 的位置编码余弦相似度高,对角方向也呈现规律性。
12.1.4 CLS Token¶
设计动机:借鉴 BERT 的 [CLS] token 设计,在 patch 序列前添加一个可学习的分类 token 。
工作流程: 1. 初始化一个可学习的向量 cls_token ∈ R^D 2. 拼接到 patch embedding 序列的最前面:[CLS, p1, p2, ..., pN] 3. 经过 Transformer 编码器后, CLS token 聚合了全局信息 4. 取 CLS token 的输出接分类头
替代方案:全局平均池化( Global Average Pooling , GAP ) - DeiT 、 BEiT 等后续工作发现 GAP 效果与 CLS token 相当 - GAP 无需额外参数,更简洁
# 方案1: CLS Token(ViT原始)
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
output = transformer_output[:, 0] # 取CLS位置
# 方案2: Global Average Pooling(很多变体采用)
output = transformer_output[:, 1:].mean(dim=1) # 对所有patch取平均
12.1.5 Multi-Head Self-Attention (MHSA)¶
MHSA 是 Transformer 的核心计算模块。对于输入序列 \(X \in \mathbb{R}^{N \times D}\):
多头机制:将 \(D\) 维分为 \(h\) 个头,每个头维度 \(d_k = D/h\),并行计算后拼接。
计算复杂度分析: - Self-Attention :\(O(N^2 \cdot D)\),\(N=196\) 时计算量可控 - 但\(N\)随分辨率平方增长:\(448 \times 448 \rightarrow N=784\),复杂度增 4 倍
class MultiHeadSelfAttention(nn.Module):
"""多头自注意力机制
实现要点:
1. QKV投影用一个线性层高效计算
2. 缩放因子 sqrt(d_k) 防止softmax梯度消失
3. 支持attention dropout
"""
def __init__(self, embed_dim=768, num_heads=12, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.scale = self.head_dim ** -0.5 # 1/sqrt(d_k)
# 一个线性层同时计算Q, K, V(效率更高)
self.qkv = nn.Linear(embed_dim, embed_dim * 3)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(embed_dim, embed_dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
# QKV投影并reshape为多头格式
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim) # 重塑张量形状
# permute将QKV维度提到最前面,unbind(0)沿第0维拆分为3个张量,避免3次独立线性层节省计算
qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, heads, N, head_dim)
q, k, v = qkv.unbind(0) # 各 (B, heads, N, head_dim)
# Scaled Dot-Product Attention
attn = (q @ k.transpose(-2, -1)) * self.scale # (B, heads, N, N)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
# 加权求和并重组
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
12.1.6 Transformer Block ( Encoder Layer )¶
ViT 使用 Pre-Norm 结构( LayerNorm 在 Attention/MLP 之前),区别于原始 Transformer 的 Post-Norm :
x → LayerNorm → MHSA → + → LayerNorm → MLP → +
↑________________________| ↑___________________|
residual residual
MLP:两层全连接 + GELU 激活,隐藏层维度通常为 4 倍 embed_dim 。
class MLP(nn.Module):
"""前馈神经网络(FFN)"""
def __init__(self, embed_dim=768, mlp_ratio=4.0, drop=0.):
super().__init__()
hidden_dim = int(embed_dim * mlp_ratio)
self.fc1 = nn.Linear(embed_dim, hidden_dim)
self.act = nn.GELU()
self.fc2 = nn.Linear(hidden_dim, embed_dim)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class TransformerBlock(nn.Module):
"""Pre-Norm Transformer Block"""
def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0,
drop=0., attn_drop=0.):
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim)
self.attn = MultiHeadSelfAttention(embed_dim, num_heads, attn_drop, drop)
self.norm2 = nn.LayerNorm(embed_dim)
self.mlp = MLP(embed_dim, mlp_ratio, drop)
def forward(self, x):
x = x + self.attn(self.norm1(x)) # 残差连接
x = x + self.mlp(self.norm2(x)) # 残差连接
return x
12.1.7 完整 ViT 模型¶
class VisionTransformer(nn.Module):
"""完整的Vision Transformer实现
ViT-Base: embed_dim=768, depth=12, num_heads=12 (86M params)
ViT-Large: embed_dim=1024, depth=24, num_heads=16 (307M params)
ViT-Huge: embed_dim=1280, depth=32, num_heads=16 (632M params)
"""
def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000,
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0,
drop_rate=0., attn_drop_rate=0.):
super().__init__()
self.num_classes = num_classes
self.embed_dim = embed_dim
# 1. Patch Embedding
self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
num_patches = self.patch_embed.n_patches
# 2. CLS Token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# 3. Position Embedding (可学习的1D)
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.pos_drop = nn.Dropout(drop_rate)
# 4. Transformer Encoder
self.blocks = nn.Sequential(*[
TransformerBlock(embed_dim, num_heads, mlp_ratio, drop_rate, attn_drop_rate)
for _ in range(depth)
])
# 5. Classification Head
self.norm = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
# 参数初始化
self._init_weights()
def _init_weights(self):
nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
self.apply(self._init_module_weights)
def _init_module_weights(self, m):
if isinstance(m, nn.Linear): # isinstance检查类型
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def forward(self, x):
B = x.shape[0]
# Step 1: Patch Embedding → (B, N, D)
x = self.patch_embed(x)
# Step 2: Prepend CLS Token → (B, N+1, D)
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat([cls_tokens, x], dim=1) # torch.cat沿已有维度拼接张量
# Step 3: Add Position Embedding
x = x + self.pos_embed
x = self.pos_drop(x)
# Step 4: Transformer Encoder
x = self.blocks(x)
# Step 5: Classification
x = self.norm(x)
cls_output = x[:, 0] # 取CLS token输出
logits = self.head(cls_output)
return logits
# ViT模型变体配置
def vit_base_patch16_224(**kwargs): # *args接收任意位置参数,**kwargs接收任意关键字参数
return VisionTransformer(
patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
def vit_large_patch16_224(**kwargs):
return VisionTransformer(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
def vit_huge_patch14_224(**kwargs):
return VisionTransformer(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs)
12.1.8 ViT 模型族参数对比¶
| 模型 | Patch Size | Embed Dim | Depth | Heads | Params | ImageNet Top-1 |
|---|---|---|---|---|---|---|
| ViT-S/16 | 16 | 384 | 12 | 6 | 22M | 79.9% |
| ViT-B/16 | 16 | 768 | 12 | 12 | 86M | 84.5% |
| ViT-L/16 | 16 | 1024 | 24 | 16 | 307M | 87.1% |
| ViT-H/14 | 14 | 1280 | 32 | 16 | 632M | 88.6% |
| ViT-G/14 | 14 | 1664 | 48 | 16 | 1843M | 90.5% |
注: Top-1 精度基于 JFT-300M 预训练 + ImageNet 微调
12.2 ViT 变体演进¶
12.2.1 DeiT — 数据高效的知识蒸馏¶
论文:Training data-efficient image transformers & distillation through attention( Touvron et al., ICML 2021 )
核心动机: ViT 需要 JFT-300M 级别数据才能超越 CNN , DeiT 证明了仅用 ImageNet-1K 也可以训练出强大的 ViT 。
关键创新: 1. 知识蒸馏 Token:在 CLS token 之外,额外添加一个 distillation token - 蒸馏 token 向教师模型( RegNetY-16GF )学习 - 最终预测 = CLS 预测 + 蒸馏 token 预测的加权平均 2. 强数据增强: RandAugment 、 Mixup 、 CutMix 、 Random Erasing 3. 正则化策略: Stochastic Depth 、 Repeated Augmentation
输入序列: [CLS_token, distill_token, patch_1, patch_2, ..., patch_N]
分类损失: CrossEntropy(CLS_output, label)
蒸馏损失: KL_Div(distill_output, teacher_output) 或 CrossEntropy(distill_output, teacher_hard_label)
DeiT-III( 2022 更新):引入 3-Augment 策略(灰度+翻转+SolarizeAdd ),比 DeiT 效果更好。
| 模型 | ImageNet Top-1 | 预训练数据 |
|---|---|---|
| DeiT-S | 79.8% | ImageNet-1K |
| DeiT-B | 81.8% | ImageNet-1K |
| DeiT-B (蒸馏) | 83.4% | ImageNet-1K |
| DeiT-III-L | 87.7% | ImageNet-21K |
12.2.2 BEiT — BERT 式视觉预训练¶
论文:BEiT: BERT Pre-Training of Image Transformers( Bao et al., ICLR 2022 )
核心思想:借鉴 BERT 的 Masked Language Modeling (MLM),提出 Masked Image Modeling (MIM)。
预训练流程: 1. 使用离散 VAE ( dVAE )将图像 patch 编码为离散 visual tokens 2. 随机 mask 约 40%的 patch 3. 让 ViT 根据未 mask 的 patch 预测被 mask 位置的 visual token
原始图像 → Patch序列 [p1, p2, ..., p196]
↓ 随机mask
掩码图像 → [p1, [MASK], p3, [MASK], p5, ...]
↓ ViT编码
预测被mask位置的visual token ID
BEiT v2( 2022 ):使用 VQ-KD (向量量化知识蒸馏)替代 dVAE ,获得更好的视觉词典。
BEiT-3( 2023 ):统一的多模态预训练,用 Multiway Transformer 同时处理图像、文本和图像-文本对。
12.2.3 MAE — 掩码自编码器¶
论文:Masked Autoencoders Are Scalable Vision Learners( He et al., CVPR 2022 )
核心创新:极高掩码率( 75%)+ 非对称编码器-解码器架构。
与 BEiT 的区别: - MAE 直接在像素空间重建(无需 visual tokenizer ) - 编码器只处理可见 patch ( 25%),极大节省计算 - 解码器轻量(仅用于预训练),下游任务只用编码器
原始图像 196个patch
↓ 随机mask 75%
可见patch (49个) → 【重型Encoder】 → 编码特征
↓ 加入mask tokens和位置编码
全部tokens (196个) → 【轻量Decoder】 → 重建像素
↓
MSE Loss (仅在被mask的位置计算)
为什么 75%掩码率有效: - 图像有大量空间冗余(相邻区域高度相关) - 高掩码率迫使模型学习高层语义,而非简单的插值 - 节省预训练时间(编码器只处理 25%的 token )
12.2.4 EVA — 大规模视觉基础模型¶
论文:EVA: Exploring the Limits of Masked Visual Representation Learning at Scale( Fang et al., CVPR 2023 )
关键创新: - 使用 CLIP 的视觉特征作为 MIM 的重建目标(而非像素或离散 token ) - 证明 MIM + CLIP 特征蒸馏是 scaling ViT 的有效方案
EVA 系列演进: | 模型 | 参数量 | ImageNet Top-1 | 核心创新 | |------|--------|---------------|---------| | EVA | 1.0B | 89.6% | CLIP 特征作为 MIM 目标 | | EVA-02 | 304M | 90.0% | RoPE 位置编码 + SwiGLU FFN | | EVA-CLIP | 5.0B | 强势开源 CLIP 系列之一 | 融合 EVA 与 CLIP 训练 |
12.2.5 SigLIP — 更好的视觉-语言对齐¶
论文:Sigmoid Loss for Language Image Pre-Training( Zhai et al., ICCV 2023 )
核心改进:用 Sigmoid loss 替代 CLIP 的 Softmax (InfoNCE) loss 。
Softmax (CLIP) vs Sigmoid (SigLIP): - CLIP :需要全局归一化,依赖大 batch size ( 32K+),分布式训练需要 all-gather 操作 - SigLIP :逐对独立计算 sigmoid ,无需全局归一化,对 batch size 不敏感,可更高效地分布式训练
SigLIP 在近年的多模态研究中被广泛采用为视觉编码器候选之一,但并非所有 VLM 都使用 SigLIP。
12.3 Swin Transformer 深入¶
12.3.1 设计动机¶
ViT 的全局自注意力存在两个问题: 1. 计算复杂度:\(O(N^2)\) 随分辨率平方增长,难以处理高分辨率图像 2. 缺乏多尺度特征:单一分辨率的 token 序列,不适合密集预测任务(检测/分割)
Swin Transformer 通过层级设计 + 窗口注意力解决这两个问题。
12.3.2 层级设计( Hierarchical Architecture )¶
Swin 采用类似 CNN 的 4 阶段层级结构,每阶段通过 Patch Merging 下采样:
Stage 1: H/4 × W/4, dim=C (56×56, 96维)
↓ Patch Merging (2×2 → 1, dim×2)
Stage 2: H/8 × W/8, dim=2C (28×28, 192维)
↓ Patch Merging
Stage 3: H/16 × W/16, dim=4C (14×14, 384维)
↓ Patch Merging
Stage 4: H/32 × W/32, dim=8C (7×7, 768维)
Patch Merging:将相邻 \(2 \times 2\) 个 patch 拼接后通过线性层降维,类似 CNN 的 stride-2 卷积。
12.3.3 窗口注意力( Window Attention )¶
将 feature map 划分为 \(M \times M\)(默认 7×7 )的不重叠窗口,在每个窗口内部做自注意力:
- 计算复杂度:\(O(M^2 \cdot N)\),其中 \(N\) 为总 token 数 → 线性复杂度
- 对比全局注意力 \(O(N^2)\),窗口注意力在高分辨率下优势巨大
12.3.4 移位窗口( Shifted Window )¶
问题:窗口注意力导致窗口之间没有信息交互。
解决方案:交替使用普通窗口和移位窗口: - 奇数层:常规窗口划分 - 偶数层:窗口向右下移动 \(M/2\) 个像素后重新划分
高效实现:通过 cyclic shift + mask 实现,避免实际创建更多窗口:
class WindowAttention(nn.Module):
"""Swin Transformer的窗口注意力(含相对位置偏置)"""
def __init__(self, dim, window_size, num_heads):
super().__init__()
self.dim = dim
self.window_size = window_size # (Wh, Ww)
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3)
self.proj = nn.Linear(dim, dim)
# 相对位置偏置表
# (2*Wh-1) * (2*Ww-1) 个可能的相对位置
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2*window_size[0]-1) * (2*window_size[1]-1), num_heads)
)
nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
# 计算每个token对的相对位置索引
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # (2, Wh, Ww) # torch.stack沿新维度拼接张量
coords_flatten = torch.flatten(coords, 1) # (2, Wh*Ww)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # (2, N, N)
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # (N, N, 2)
relative_coords[:, :, 0] += window_size[0] - 1
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # (N, N)
self.register_buffer("relative_position_index", relative_position_index)
def forward(self, x, mask=None):
B_, N, C = x.shape # B_为窗口总数
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
attn = (q @ k.transpose(-2, -1)) * self.scale
# 添加相对位置偏置
# 用index从偏置表查找→view(N,N,heads)→permute为(heads,N,N)→unsqueeze(0)加batch维以广播加到注意力分数上
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(N, N, -1).permute(2, 0, 1).contiguous()
attn = attn + relative_position_bias.unsqueeze(0) # unsqueeze增加一个维度
# 移位窗口的mask
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N)
attn = attn + mask.unsqueeze(1).unsqueeze(0) # 被mask位置设为-inf
attn = attn.view(-1, self.num_heads, N, N)
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
return x
def window_partition(x, window_size):
"""将feature map划分为窗口"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows # (num_windows*B, window_size, window_size, C)
def window_reverse(windows, window_size, H, W):
"""将窗口还原为feature map"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
12.3.5 Swin Transformer 模型族¶
| 模型 | Embed Dim | Depths | Heads | Params | ImageNet Top-1 |
|---|---|---|---|---|---|
| Swin-T | 96 | [2,2,6,2] | [3,6,12,24] | 29M | 81.3% |
| Swin-S | 96 | [2,2,18,2] | [3,6,12,24] | 50M | 83.0% |
| Swin-B | 128 | [2,2,18,2] | [4,8,16,32] | 88M | 83.5% |
| Swin-L | 192 | [2,2,18,2] | [6,12,24,48] | 197M | 86.4% |
Swin v2( 2022 ):引入 log-CPB (对数连续位置偏置)和余弦注意力,支持更高分辨率和更大模型。
12.4 高效 ViT¶
12.4.1 EfficientViT (MIT)¶
论文:EfficientViT: Lightweight Multi-Scale Attention for High-Resolution Dense Prediction( Cai et al., ICCV 2023 )
核心创新: - 多尺度线性注意力:用线性注意力(\(O(N)\))替代标准 softmax 注意力(\(O(N^2)\)) - 级联分组注意力:将不同 head 分配给不同的特征分片,减少冗余计算 - 在分割任务( Cityscapes 、 ADE20K )上以较低延迟取得有竞争力的结果
| 模型 | Params | FLOPs | ImageNet Top-1 | GPU 延迟 |
|---|---|---|---|---|
| EfficientViT-B1 | 9.1M | 0.52G | 79.4% | 0.3ms |
| EfficientViT-B2 | 24.3M | 1.6G | 82.1% | 0.6ms |
| EfficientViT-B3 | 48.6M | 4.0G | 83.5% | 1.1ms |
12.4.2 MobileViT¶
论文:MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer( Mehta & Rastegari, ICLR 2022 )
设计理念:结合 MobileNet 的轻量 CNN 和 ViT 的全局注意力。
- 用 MobileNetV2 的 Inverted Residual Block 提取局部特征
- 在关键阶段插入 Transformer Block 捕捉全局依赖
- 参数量仅 5-6M ,适合移动端部署
12.4.3 TinyViT¶
论文:TinyViT: Fast Pretraining Distillation for Small Vision Transformers( Wu et al., ECCV 2022 )
方法:通过大模型蒸馏快速预训练小 ViT ( 5M-21M 参数),在 ImageNet 上达到与大模型相当的精度。
| 模型 | Params | ImageNet Top-1 | 用途 |
|---|---|---|---|
| TinyViT-5M | 5.4M | 79.1% | 移动端推理 |
| TinyViT-11M | 11M | 81.5% | 边缘设备 |
| TinyViT-21M | 21M | 83.2% | 轻量服务端 |
12.5 ViT 在下游任务中的应用¶
12.5.1 目标检测 — DETR¶
论文:End-to-End Object Detection with Transformers( Carion et al., ECCV 2020 )
革命性意义:去掉了 anchor 、 NMS 等手工设计组件,用 Transformer 实现端到端检测。
架构:
图像 → CNN Backbone → Transformer Encoder → Transformer Decoder → 预测框+类别
↑
Object Queries (可学习)
匹配策略:用匈牙利算法做预测框与 GT 的二分匹配( Set Prediction )。
演进: DETR → Deformable DETR → DINO → Co-DETR → RT-DETR (实时版)
12.5.2 图像分割 — SAM¶
论文:Segment Anything( Kirillov et al., ICCV 2023 )
SAM (Segment Anything Model) 是视觉基础模型的里程碑: - Image Encoder: ViT-H ( MAE 预训练),提取图像特征 - Prompt Encoder:编码点击、框、文本等提示 - Mask Decoder:轻量 Transformer 解码器,生成分割掩码
SAM 2( 2024 ):扩展到视频分割,引入 Memory Mechanism 追踪时序信息。
12.5.3 图像生成 — DiT¶
论文:Scalable Diffusion Models with Transformers( Peebles & Xie, ICCV 2023 )
DiT (Diffusion Transformer):用 ViT 替代 U-Net 作为扩散模型的去噪网络。
核心设计: - 输入:带噪声的 latent patch 序列 + 时间步 embedding + 类别 embedding - 用 AdaLN-Zero (自适应 LayerNorm )注入条件信息 - DiT 类时空 token 架构常被视为理解视频生成系统的重要参考,但不应直接等同于某个闭源系统的完整官方实现
Noisy Latent → Patchify → Transformer Blocks (with AdaLN) → Unpatchify → Denoised Latent
↑ timestep + class label
12.5.4 其他重要应用¶
| 任务 | 代表模型 | 核心用法 |
|---|---|---|
| 语义分割 | SegFormer, Mask2Former | ViT 作为 backbone + 分割解码器 |
| 深度估计 | DPT, Depth Anything | ViT 编码器 + 多尺度解码器 |
| 视频理解 | VideoMAE, TimeSformer | 时空 patch + 视频 Transformer |
| 点云处理 | Point-BERT, Point-MAE | 3D 点云 patch 化 + ViT |
| 医学影像 | MedViT, SAM-Med | 预训练 ViT + 领域微调 |
12.6 ViT vs CNN 对比分析¶
12.6.1 归纳偏置( Inductive Bias )¶
| 特性 | CNN | ViT |
|---|---|---|
| 局部性 | 卷积核限制感受野,天然捕捉局部模式 | 全局注意力,需要从数据中学习局部性 |
| 平移等变性 | 卷积权重共享保证平移等变 | 无此先验,依赖数据增强和位置编码 |
| 层级特征 | pooling 自然构建多尺度 | 原始 ViT 单尺度,需 Swin 等设计引入 |
| 训练效率 | 归纳偏置帮助小数据学习 | 小数据下表现差,大数据下优势显现 |
12.6.2 数据效率¶
结论:ViT 往往是一个更“通用”但也更“依赖数据和预训练”的学习器;在大数据或强预训练条件下,它常能学到与 CNN 互补、并在部分任务上更强的表示,但小数据场景并不总是占优。
12.6.3 计算复杂度¶
对于输入分辨率 \(H \times W\)、 patch 数 \(N = HW/P^2\):
| 操作 | 复杂度 | 说明 |
|---|---|---|
| ViT Self-Attention | \(O(N^2 \cdot D)\) | 随分辨率二次增长 |
| Swin Window Attention | \(O(M^2 \cdot N \cdot D)\) | 线性于 N |
| CNN 3×3 Conv | \(O(9 \cdot C^2 \cdot N)\) | 线性于 N |
| Linear Attention | \(O(N \cdot D^2)\) | 线性于 N |
12.6.4 实践建议¶
| 场景 | 推荐架构 | 原因 |
|---|---|---|
| 预训练数据充足 | ViT/EVA | 往往更容易体现 scaling 优势 |
| 目标检测/分割 | Swin/ViTDet | 多尺度特征+高分辨率 |
| 移动端部署 | MobileViT/EfficientViT | 轻量高效 |
| 中小数据集 | ConvNeXt/CNN+ViT 混合 | 归纳偏置有帮助 |
| 多模态应用 | ViT+SigLIP | 与语言模型对接方便 |
12.7 使用预训练 ViT 模型¶
12.7.1 HuggingFace 推理¶
import torch
from pathlib import Path
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
# 加载预训练模型和处理器
model_name = 'google/vit-base-patch16-224'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AutoModelForImageClassification.from_pretrained(model_name).to(device)
processor = AutoImageProcessor.from_pretrained(model_name)
model.eval() # eval()评估模式
# 推理;该示例是单张图片分类教学骨架
image_path = Path('image.jpg')
if not image_path.exists():
raise FileNotFoundError(f"找不到图像文件: {image_path}")
image = Image.open(image_path).convert('RGB')
inputs = processor(images=image, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad(): # 禁用梯度计算,节省内存
outputs = model(**inputs)
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item() # 将单元素张量转为Python数值
predicted_class = model.config.id2label[predicted_class_idx]
confidence = torch.softmax(logits, dim=-1).max().item()
print(f"预测类别: {predicted_class}, 置信度: {confidence:.4f}")
12.7.2 timm 库使用¶
import timm
import torch
from pathlib import Path
from PIL import Image
from timm.data import resolve_data_config, create_transform
# timm 覆盖了大量 ViT 变体;具体模型名会随安装版本变化
model_name = 'vit_base_patch16_224.augreg_in21k_ft_in1k'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = timm.create_model(model_name, pretrained=True).to(device)
model.eval()
# 自动获取模型对应的数据预处理
config = resolve_data_config(model.pretrained_cfg)
transform = create_transform(**config)
image_path = Path('image.jpg')
if not image_path.exists():
raise FileNotFoundError(f"找不到图像文件: {image_path}")
image = Image.open(image_path).convert('RGB')
input_tensor = transform(image).unsqueeze(0).to(device)
with torch.no_grad():
output = model(input_tensor)
probabilities = torch.softmax(output, dim=-1)
top5_prob, top5_idx = probabilities.topk(5)
for i in range(5):
print(f"Top-{i+1}: class={top5_idx[0][i].item()}, prob={top5_prob[0][i].item():.4f}")
# 列出当前 timm 版本下可用的 ViT 模型
vit_models = timm.list_models('vit_*', pretrained=True)
print(f"可用ViT模型数量: {len(vit_models)}")
12.7.3 ViT 微调示例¶
import torch
import torch.nn as nn
import timm
# 加载预训练ViT并修改分类头
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)
model = model.to(device)
# 下面示例默认“只训练分类头”;若数据更多,可改为全量微调并使用更小学习率
freeze_backbone = True
if freeze_backbone:
for name, param in model.named_parameters():
if 'head' not in name:
param.requires_grad = False
optimizer = torch.optim.AdamW(model.head.parameters(), lr=1e-3, weight_decay=0.05)
else:
optimizer = torch.optim.AdamW([
{'params': model.head.parameters(), 'lr': 1e-3},
{'params': [p for n, p in model.named_parameters() if 'head' not in n], 'lr': 1e-5},
], weight_decay=0.05)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
# 训练循环;假定 dataloader 已输出 resize/normalize 后的 224x224 张量
def train_one_epoch(model, dataloader, optimizer, criterion, device):
if len(dataloader) == 0:
raise ValueError("dataloader 为空,无法执行训练")
model.train()
total_loss, correct, total = 0, 0, 0
for images, labels in dataloader:
images, labels = images.to(device), labels.to(device) # 移至GPU/CPU
outputs = model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad() # 清零梯度
loss.backward() # 反向传播计算梯度
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step() # 更新参数
total_loss += loss.item()
correct += (outputs.argmax(1) == labels).sum().item()
total += labels.size(0)
return total_loss / len(dataloader), correct / total
12.8 练习题¶
基础题¶
- 简答题:
- ViT 的 Patch Embedding 相当于什么操作?为什么用 Conv2d 实现?
Patch Embedding 相当于用 stride=patch_size 的 Conv2d 对图像做不重叠卷积,将每个 patch 投影为一个向量。用 Conv2d 实现而非先 reshape 再线性层,是因为 Conv2d 能直接完成“切块+投影”两步,硬件利用率更高、实现更简洁。
- 解释 Pre-Norm 和 Post-Norm 的区别, ViT 使用哪种?
Post-Norm在残差连接之后做 LayerNorm (原始 Transformer ),训练时梯度不稳定,需要学习率 warmup 。Pre-Norm在残差连接之前做 LayerNorm ,梯度可以直接通过残差路径流动,训练更稳定,适合深层网络。 ViT 使用 Pre-Norm 。
-
为什么 ViT 需要位置编码?可学习 1D 和 2D 位置编码效果有区别吗?
Transformer 的 Self-Attention 是置换不变的( permutation-invariant ),无法感知 token 的空间位置,因此必须添加位置编码。实验表明可学习 1D 和 2D 位置编码效果差别很小(约±0.1%),因为 1D 编码在训练中可以自动学到类似 2D 的结构信息。
-
计算题:
- 计算 ViT-B/16 在 224×224 输入时的 patch 数量、序列长度和 Self-Attention 的 FLOPs 。
Patch 数量 = \((224/16)^2 = 14 \times 14 = 196\);加上 CLS token 后序列长度 \(N = 196 + 1 = 197\)。 ViT-B 的隐藏维度 \(D = 768\)。单层 Self-Attention 的 FLOPs ≈ \(2 \times 4N^2D = 2 \times 4 \times 197^2 \times 768 \approx 238M\)(包含 Q/K/V 投影和注意力计算)。 ViT-B 有 12 层,总注意力 FLOPs ≈ \(12 \times 238M \approx 2.86G\)。
- 如果将输入分辨率改为 384×384 ,序列长度和计算量如何变化?
Patch 数量 = \((384/16)^2 = 24 \times 24 = 576\),序列长度 \(N = 577\)。相比 224 分辨率的 197 ,序列长度增加约 2.93 倍。由于 Self-Attention 计算量与 \(N^2\) 成正比,注意力 FLOPs 增加约 \((577/197)^2 \approx 8.58\) 倍。总模型 FLOPs 增加约 3 倍左右(因为 MLP 部分仅线性增长)。
进阶题¶
- 编程题:
- 从零实现一个完整的 ViT (包含 Patch Embedding 、 Position Embedding 、 MHA 、 MLP 、 CLS Token )。
-
使用 timm 加载预训练 Swin-T ,在 CIFAR-10 上微调。
-
分析题:
- 比较 MAE 和 BEiT 的预训练策略,各有什么优劣?
MAE随机掩码图像 75%的 patch ,编码器只处理可见 patch ,解码器重建原始像素。优势:预训练效率高( 3-4×加速)、无需额外组件、概念简洁。劣势:重建低层像素可能不够语义化。BEiT先用 dVAE 将图像编码为离散 token ,然后预测被掩码位置的 token ID 。优势:预测目标更语义化。劣势:需要预训练 dVAE tokenizer ,流程更复杂。
- 为什么 DiT 在部分扩散模型中越来越常见?
①Scaling 能力强: Transformer 更容易按层数和宽度扩展;②条件注入灵活: AdaLN-Zero 等机制便于加入 timestep 和 class/text 条件;③生态统一:与 ViT/LLM 组件更容易复用;④预训练兼容:可借用部分 ViT 预训练经验。也要注意, U-Net 在图像扩散中依然常见,是否替换取决于分辨率、算力预算与任务形态。
12.9 关键复盘¶
高频复盘题¶
Q1: ViT 相比 CNN 有什么优势和劣势?在什么场景下选择 ViT ?
参考答案: 优势: - 全局感受野:第一层就能看到所有 patch 之间的关系, CNN 需要很深才能获得大感受野 - 可扩展性强:参数量和数据量增加时性能持续提升( Scaling Law ) - 架构统一:与 NLP 的 Transformer 共享架构,便于多模态融合 - 预训练迁移好: MAE/CLIP 等大规模预训练后迁移效果出色
劣势: - 数据饥渴:小数据集上不如 CNN (缺乏 locality 、 translation equivariance 等归纳偏置) - 计算复杂度高: Self-Attention 的 O(N²)使高分辨率输入开销大 - 对分辨率敏感:位置编码与输入尺寸绑定,分辨率变化需要插值
选择建议:预训练数据充足时可优先考虑 ViT ;小数据用 CNN 或混合架构;移动端用 EfficientViT/MobileViT ;密集预测用 Swin/ViTDet 。
Q2: 请解释 Swin Transformer 的移位窗口机制及其作用
参考答案: Swin Transformer 在相邻层交替使用两种窗口配置: - 常规窗口( W-MSA ):将 feature map 均匀划分为 \(M \times M\) 窗口 - 移位窗口( SW-MSA ):将窗口向右下偏移 \(M/2\),跨越常规窗口边界
作用: 1. 允许相邻窗口之间交换信息,避免信息孤岛 2. 保持线性计算复杂度(只在窗口内做注意力) 3. 高效实现:用 cyclic shift + attention mask ,无需创建额外窗口
Q3: MAE 为什么要用 75%这么高的掩码率?
参考答案: 1. 图像的空间冗余高:相邻像素高度相关,低掩码率下模型可以轻松"插值"恢复,学不到高层语义 2. 高掩码率迫使模型理解语义:看到极少信息时必须"推理"内容是什么 3. 计算效率:编码器只处理 25%的 visible tokens ,预训练速度提升 3-4 倍 4. 与 NLP 对比: BERT 用 15%的掩码率,因为语言信息密度高,而图像信息密度低
Q4: DETR 如何实现端到端目标检测?相比传统检测器有什么优势?
参考答案: DETR 流程: 1. CNN backbone 提取特征 → Transformer Encoder 全局建模 2. \(N\) 个可学习的 Object Queries 经 Transformer Decoder 交叉注意力 3. 每个 query 直接输出(class, box)预测 4. 用匈牙利算法做预测与 GT 的二分匹配,计算集合级别损失
优势:去掉了 anchor 设计、 NMS 后处理、 FPN 等手工组件,简化流程 劣势:训练收敛慢( 500 epochs )、小目标检测弱 → Deformable DETR 解决了这些问题
Q5: DiT 为什么在部分大规模扩散模型中越来越常见?
参考答案: 1. Scaling 能力: DiT 基于 Transformer,通常更容易沿层数/宽度扩展 2. 条件注入灵活: AdaLN-Zero 等设计便于接入 timestep 和 class/text 条件 3. 与统一骨干兼容:在多模态或视频生成体系里更容易与 ViT/LLM 组件协同 4. 预训练迁移:可借鉴部分 ViT 预训练经验 5. 补充判断:U-Net 仍广泛用于图像扩散, DiT 更像是在“大模型化”路线里增长很快,而非单向替代
Q6: SigLIP 相比 CLIP 有什么改进?为什么在部分 VLM 中更常见?
参考答案: 改进:用 Sigmoid loss 替代 InfoNCE (Softmax) loss - InfoNCE 需要全 batch 的负样本做归一化 → 需要超大 batch size 和 all-gather - Sigmoid loss 对每个正/负样本对独立计算 → 不依赖 batch size ,分布式训练更高效
在部分 VLM 中更常见的原因: - 训练更稳定、效率更高 - 相同参数量下 zero-shot 性能更好 - 在一批多模态系统中被采用,但并非唯一选择
Q7: 如何将 ViT 的位置编码从 224×224 分辨率迁移到 384×384 ?
参考答案: 使用双线性插值调整位置编码维度: 1. 将 \(14 \times 14\) 的 2D 位置编码 reshape 2. 用 F.interpolate 双线性插值到 \(24 \times 24\) 3. 再展平回 1D 序列 4. CLS token 的位置编码保持不变
pos_embed_2d = pos_embed[:, 1:].reshape(1, 14, 14, D).permute(0, 3, 1, 2)
pos_embed_2d = F.interpolate(pos_embed_2d, size=(24, 24), mode='bicubic', align_corners=False) # F.xxx PyTorch函数式API
new_pos_embed = torch.cat([pos_embed[:, :1], pos_embed_2d.flatten(2).transpose(1, 2)], dim=1)
Q8: 比较 ViT 、 Swin Transformer 和 ConvNeXt 三个架构的设计哲学
参考答案:
| 特性 | ViT | Swin Transformer | ConvNeXt |
|---|---|---|---|
| 设计哲学 | 纯 Transformer ,最小视觉先验 | Transformer + CNN 层级设计 | 纯 CNN + Transformer 训练技巧 |
| 注意力范围 | 全局 | 局部窗口 | 局部( 7×7 深度可分离卷积) |
| 多尺度特征 | 无(单尺度) | 有( 4 阶段层级) | 有( 4 阶段层级) |
| 计算复杂度 | O(N²) | O(N) | O(N) |
| 更常见使用场景 | 基础模型预训练 | 密集预测任务 | 需要 CNN 生态兼容 |
| 代表应用 | CLIP/MAE/DiT | 检测/分割 backbone | 替代 ResNet |
12.10 关键论文列表¶
必读论文¶
| 年份 | 论文 | 核心贡献 |
|---|---|---|
| 2020 | ViT: An Image is Worth 16x16 Words | 开创性地将纯 Transformer 用于视觉 |
| 2021 | DeiT: Training data-efficient image transformers | 知识蒸馏+强数据增强训练 ViT |
| 2021 | Swin Transformer: Hierarchical Vision Transformer | 移位窗口+层级结构,适配密集预测 |
| 2021 | BEiT: BERT Pre-Training of Image Transformers | Masked Image Modeling 预训练 |
| 2022 | MAE: Masked Autoencoders Are Scalable Vision Learners | 75%掩码+像素重建的高效预训练 |
| 2022 | ConvNeXt: A ConvNet for the 2020s | 用 Transformer 技巧现代化 CNN |
扩展论文¶
| 年份 | 论文 | 核心贡献 |
|---|---|---|
| 2020 | DETR: End-to-End Object Detection with Transformers | Transformer 端到端检测 |
| 2023 | EVA-02: A Visual Representation for Neon Genesis | RoPE + SwiGLU 的强力 ViT |
| 2023 | SigLIP: Sigmoid Loss for Language Image Pre-Training | 更高效的视觉-语言对比学习 |
| 2023 | SAM: Segment Anything | ViT 驱动的视觉分割基础模型 |
| 2023 | DiT: Scalable Diffusion Models with Transformers | ViT 替代 U-Net 做扩散模型 |
| 2023 | DINOv2: Learning Robust Visual Features | 自监督 ViT 视觉基础模型 |
| 2024 | SAM 2: Segment Anything in Images and Videos | ViT+Memory 的视频分割 |
12.11 本章小结¶
核心知识点¶
- ViT 原始架构: Patch Embedding → CLS Token + Position Embedding → Transformer Encoder → Classification Head
- ViT 变体: DeiT (蒸馏)、 BEiT ( MIM )、 MAE (掩码自编码)、 EVA (大规模)、 SigLIP (高效对比学习)
- Swin Transformer:层级设计 + 窗口注意力 + 移位窗口 + 相对位置偏置
- 高效 ViT: EfficientViT 、 MobileViT 、 TinyViT——面向端侧部署
- 下游应用: DETR (检测)、 SAM (分割)、 DiT (生成)
- ViT vs CNN:归纳偏置、数据效率、计算复杂度的全面对比
下一步¶
下一章:13-多模态学习.md - 学习 CLIP 、 BLIP 、 LLaVA 等视觉-语言多模态模型
恭喜完成第 12 章! 🎉 视觉 Transformer 已成为现代 CV 中极具影响力的一条主线——从分类、检测、分割到生成,你会反复看到 ViT 及其变体。
⚠️ 核验说明(2026-04-03):已逐段复核 ViT / SigLIP / DiT 相关表述,并按官方文档将 Hugging Face 推理示例改为更稳定的
AutoImageProcessor/AutoModelForImageClassification接口;timm模型名仍可能随版本变化,请以本机安装结果为准。
最后更新日期: 2026-04-03