跳转至

第 12 章 视觉 Transformer

📚 章节概述

本章深入讲解 Transformer 在计算机视觉中的关键应用。从 2020 年 ViT 横空出世,到近年的分类、分割、生成与多模态模型广泛采用 ViT 系路线,视觉建模范式发生了明显变化。本章将从 ViT 原始架构出发,涵盖代表性变体演进、高效设计、下游任务应用,以及工程实践中的关键代码实现。

学习时间: 7-10 天 难度等级:⭐⭐⭐⭐⭐ 前置知识:第 5-6 章 CNN 基础、 NLP 领域 Transformer 架构( Self-Attention / Multi-Head Attention / Positional Encoding )

🎯 学习目标

完成本章后,你将能够: - 深入理解 ViT 的每个组件( Patch Embedding / Position Embedding / CLS Token / MHA ) - 掌握 Swin Transformer 的层级设计与窗口注意力机制 - 了解 DeiT 、 BEiT 、 MAE 、 EVA 、 SigLIP 等重要变体的核心创新 - 比较 ViT 与 CNN 的归纳偏置、数据效率、计算复杂度差异 - 理解 ViT 在检测( DETR )、分割( SAM )、生成( DiT )等下游任务的应用 - 能够用 PyTorch 手写完整的 ViT 前向传播 - 熟练使用 HuggingFace 预训练 ViT 模型 - 掌握 8 组关键复盘问题


12.1 ViT 原始架构深度解析

12.1.1 论文背景

论文An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale( Dosovitskiy et al., ICLR 2021 )

核心思想:完全抛弃卷积操作,将图像视为一组 patch 序列,直接使用标准 Transformer 编码器进行分类。在 JFT-300M 等大规模数据预训练条件下, ViT 在当时的 ImageNet 评测中取得了很有竞争力的结果。

关键发现: - Transformer 在大规模数据下可以超越 CNN - 在中小数据集上 ViT 不如 CNN (缺乏归纳偏置) - Scaling Law 在视觉领域同样成立

12.1.2 Patch Embedding

Patch Embedding 是 ViT 将 2D 图像转换为 1D 序列的核心步骤。

原理: 1. 将 \(H \times W \times C\) 的图像切分为 \(N\) 个不重叠的 patch ,每个 patch 大小为 \(P \times P\) 2. patch 数量 \(N = \frac{H \times W}{P^2}\),例如 \(224 \times 224\) 图像、\(P=16\)\(N = 196\) 3. 每个 patch 展平为 \(P^2 \cdot C\) 维向量(\(16 \times 16 \times 3 = 768\)) 4. 通过线性投影映射到 \(D\) 维嵌入空间

实现方式:实践中通常用一个 kernel_size=stride=P 的卷积层高效实现:

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PatchEmbedding(nn.Module):  # 继承nn.Module定义网络层
    """将图像分割为patches并投影到嵌入空间

    用Conv2d实现,等价于:切patch → 展平 → 线性投影
    但Conv2d利用了GPU并行计算,效率更高
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()  # super()调用父类方法
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2  # 196 for 224/16

        # Conv2d with kernel_size=stride=patch_size 等价于切patch+线性投影
        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # x: (B, C, H, W) e.g. (B, 3, 224, 224)
        x = self.proj(x)        # (B, embed_dim, H/P, W/P) e.g. (B, 768, 14, 14)
        x = x.flatten(2)        # (B, embed_dim, N) e.g. (B, 768, 196)
        x = x.transpose(1, 2)   # (B, N, embed_dim) e.g. (B, 196, 768)
        return x

12.1.3 Position Embedding

Transformer 本身对输入顺序不敏感(排列等变性),因此必须额外注入位置信息。

ViT 使用可学习的 1D 位置编码

位置编码类型 描述 代表模型
可学习 1D 每个位置一个可训练向量,直接加到 patch embedding 上 ViT
正弦余弦 使用固定的 sin/cos 函数,不需要训练 原始 Transformer
可学习 2D 分别为行和列学习位置编码 DeiT-III
相对位置偏置 编码 patch 对之间的相对距离 Swin Transformer
旋转位置编码(RoPE) 通过旋转矩阵注入位置 EVA-02
无位置编码 利用卷积隐式编码位置 CvT

关键细节: ViT 论文实验表明,可学习 1D 与可学习 2D 位置编码效果几乎一致,因此默认采用更简单的 1D 方案。

Python
# 可学习的1D位置编码
# +1是因为还有一个CLS token
pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
nn.init.trunc_normal_(pos_embed, std=0.02)

位置编码的可视化:训练后的位置编码会自动学习到 2D 空间结构——相邻 patch 的位置编码余弦相似度高,对角方向也呈现规律性。

12.1.4 CLS Token

设计动机:借鉴 BERT 的 [CLS] token 设计,在 patch 序列前添加一个可学习的分类 token 。

工作流程: 1. 初始化一个可学习的向量 cls_token ∈ R^D 2. 拼接到 patch embedding 序列的最前面:[CLS, p1, p2, ..., pN] 3. 经过 Transformer 编码器后, CLS token 聚合了全局信息 4. 取 CLS token 的输出接分类头

替代方案:全局平均池化( Global Average Pooling , GAP ) - DeiT 、 BEiT 等后续工作发现 GAP 效果与 CLS token 相当 - GAP 无需额外参数,更简洁

Python
# 方案1: CLS Token(ViT原始)
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
output = transformer_output[:, 0]  # 取CLS位置

# 方案2: Global Average Pooling(很多变体采用)
output = transformer_output[:, 1:].mean(dim=1)  # 对所有patch取平均

12.1.5 Multi-Head Self-Attention (MHSA)

MHSA 是 Transformer 的核心计算模块。对于输入序列 \(X \in \mathbb{R}^{N \times D}\)

\[Q = XW_Q, \quad K = XW_K, \quad V = XW_V\]
\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V\]

多头机制:将 \(D\) 维分为 \(h\) 个头,每个头维度 \(d_k = D/h\),并行计算后拼接。

计算复杂度分析: - Self-Attention :\(O(N^2 \cdot D)\)\(N=196\) 时计算量可控 - 但\(N\)随分辨率平方增长:\(448 \times 448 \rightarrow N=784\),复杂度增 4 倍

Python
class MultiHeadSelfAttention(nn.Module):
    """多头自注意力机制

    实现要点:
    1. QKV投影用一个线性层高效计算
    2. 缩放因子 sqrt(d_k) 防止softmax梯度消失
    3. 支持attention dropout
    """
    def __init__(self, embed_dim=768, num_heads=12, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5  # 1/sqrt(d_k)

        # 一个线性层同时计算Q, K, V(效率更高)
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape

        # QKV投影并reshape为多头格式
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)  # 重塑张量形状
        # permute将QKV维度提到最前面,unbind(0)沿第0维拆分为3个张量,避免3次独立线性层节省计算
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
        q, k, v = qkv.unbind(0)  # 各 (B, heads, N, head_dim)

        # Scaled Dot-Product Attention
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # 加权求和并重组
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

12.1.6 Transformer Block ( Encoder Layer )

ViT 使用 Pre-Norm 结构( LayerNorm 在 Attention/MLP 之前),区别于原始 Transformer 的 Post-Norm :

Text Only
x → LayerNorm → MHSA → + → LayerNorm → MLP → +
↑________________________|  ↑___________________|
       residual                  residual

MLP:两层全连接 + GELU 激活,隐藏层维度通常为 4 倍 embed_dim 。

Python
class MLP(nn.Module):
    """前馈神经网络(FFN)"""
    def __init__(self, embed_dim=768, mlp_ratio=4.0, drop=0.):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class TransformerBlock(nn.Module):
    """Pre-Norm Transformer Block"""
    def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0,
                 drop=0., attn_drop=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, attn_drop, drop)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_ratio, drop)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))   # 残差连接
        x = x + self.mlp(self.norm2(x))    # 残差连接
        return x

12.1.7 完整 ViT 模型

Python
class VisionTransformer(nn.Module):
    """完整的Vision Transformer实现

    ViT-Base:  embed_dim=768,  depth=12, num_heads=12  (86M params)
    ViT-Large: embed_dim=1024, depth=24, num_heads=16  (307M params)
    ViT-Huge:  embed_dim=1280, depth=32, num_heads=16  (632M params)
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0,
                 drop_rate=0., attn_drop_rate=0.):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim

        # 1. Patch Embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.n_patches

        # 2. CLS Token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # 3. Position Embedding (可学习的1D)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(drop_rate)

        # 4. Transformer Encoder
        self.blocks = nn.Sequential(*[
            TransformerBlock(embed_dim, num_heads, mlp_ratio, drop_rate, attn_drop_rate)
            for _ in range(depth)
        ])

        # 5. Classification Head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        # 参数初始化
        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_module_weights)

    def _init_module_weights(self, m):
        if isinstance(m, nn.Linear):  # isinstance检查类型
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x):
        B = x.shape[0]

        # Step 1: Patch Embedding → (B, N, D)
        x = self.patch_embed(x)

        # Step 2: Prepend CLS Token → (B, N+1, D)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # torch.cat沿已有维度拼接张量

        # Step 3: Add Position Embedding
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Step 4: Transformer Encoder
        x = self.blocks(x)

        # Step 5: Classification
        x = self.norm(x)
        cls_output = x[:, 0]  # 取CLS token输出
        logits = self.head(cls_output)
        return logits

# ViT模型变体配置
def vit_base_patch16_224(**kwargs):  # *args接收任意位置参数,**kwargs接收任意关键字参数
    return VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)

def vit_large_patch16_224(**kwargs):
    return VisionTransformer(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)

def vit_huge_patch14_224(**kwargs):
    return VisionTransformer(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs)

12.1.8 ViT 模型族参数对比

模型 Patch Size Embed Dim Depth Heads Params ImageNet Top-1
ViT-S/16 16 384 12 6 22M 79.9%
ViT-B/16 16 768 12 12 86M 84.5%
ViT-L/16 16 1024 24 16 307M 87.1%
ViT-H/14 14 1280 32 16 632M 88.6%
ViT-G/14 14 1664 48 16 1843M 90.5%

注: Top-1 精度基于 JFT-300M 预训练 + ImageNet 微调


12.2 ViT 变体演进

12.2.1 DeiT — 数据高效的知识蒸馏

论文Training data-efficient image transformers & distillation through attention( Touvron et al., ICML 2021 )

核心动机: ViT 需要 JFT-300M 级别数据才能超越 CNN , DeiT 证明了仅用 ImageNet-1K 也可以训练出强大的 ViT 。

关键创新: 1. 知识蒸馏 Token:在 CLS token 之外,额外添加一个 distillation token - 蒸馏 token 向教师模型( RegNetY-16GF )学习 - 最终预测 = CLS 预测 + 蒸馏 token 预测的加权平均 2. 强数据增强: RandAugment 、 Mixup 、 CutMix 、 Random Erasing 3. 正则化策略: Stochastic Depth 、 Repeated Augmentation

Text Only
输入序列: [CLS_token, distill_token, patch_1, patch_2, ..., patch_N]
分类损失: CrossEntropy(CLS_output, label)
蒸馏损失: KL_Div(distill_output, teacher_output) 或 CrossEntropy(distill_output, teacher_hard_label)

DeiT-III( 2022 更新):引入 3-Augment 策略(灰度+翻转+SolarizeAdd ),比 DeiT 效果更好。

模型 ImageNet Top-1 预训练数据
DeiT-S 79.8% ImageNet-1K
DeiT-B 81.8% ImageNet-1K
DeiT-B (蒸馏) 83.4% ImageNet-1K
DeiT-III-L 87.7% ImageNet-21K

12.2.2 BEiT — BERT 式视觉预训练

论文BEiT: BERT Pre-Training of Image Transformers( Bao et al., ICLR 2022 )

核心思想:借鉴 BERT 的 Masked Language Modeling (MLM),提出 Masked Image Modeling (MIM)。

预训练流程: 1. 使用离散 VAE ( dVAE )将图像 patch 编码为离散 visual tokens 2. 随机 mask 约 40%的 patch 3. 让 ViT 根据未 mask 的 patch 预测被 mask 位置的 visual token

Text Only
原始图像 → Patch序列 [p1, p2, ..., p196]
                           ↓ 随机mask
掩码图像 → [p1, [MASK], p3, [MASK], p5, ...]
                           ↓ ViT编码
                    预测被mask位置的visual token ID

BEiT v2( 2022 ):使用 VQ-KD (向量量化知识蒸馏)替代 dVAE ,获得更好的视觉词典。

BEiT-3( 2023 ):统一的多模态预训练,用 Multiway Transformer 同时处理图像、文本和图像-文本对。

12.2.3 MAE — 掩码自编码器

论文Masked Autoencoders Are Scalable Vision Learners( He et al., CVPR 2022 )

核心创新:极高掩码率( 75%)+ 非对称编码器-解码器架构。

与 BEiT 的区别: - MAE 直接在像素空间重建(无需 visual tokenizer ) - 编码器只处理可见 patch ( 25%),极大节省计算 - 解码器轻量(仅用于预训练),下游任务只用编码器

Text Only
原始图像 196个patch
   ↓ 随机mask 75%
可见patch (49个) → 【重型Encoder】 → 编码特征
   ↓ 加入mask tokens和位置编码
全部tokens (196个) → 【轻量Decoder】 → 重建像素
MSE Loss (仅在被mask的位置计算)

为什么 75%掩码率有效: - 图像有大量空间冗余(相邻区域高度相关) - 高掩码率迫使模型学习高层语义,而非简单的插值 - 节省预训练时间(编码器只处理 25%的 token )

12.2.4 EVA — 大规模视觉基础模型

论文EVA: Exploring the Limits of Masked Visual Representation Learning at Scale( Fang et al., CVPR 2023 )

关键创新: - 使用 CLIP 的视觉特征作为 MIM 的重建目标(而非像素或离散 token ) - 证明 MIM + CLIP 特征蒸馏是 scaling ViT 的有效方案

EVA 系列演进: | 模型 | 参数量 | ImageNet Top-1 | 核心创新 | |------|--------|---------------|---------| | EVA | 1.0B | 89.6% | CLIP 特征作为 MIM 目标 | | EVA-02 | 304M | 90.0% | RoPE 位置编码 + SwiGLU FFN | | EVA-CLIP | 5.0B | 强势开源 CLIP 系列之一 | 融合 EVA 与 CLIP 训练 |

12.2.5 SigLIP — 更好的视觉-语言对齐

论文Sigmoid Loss for Language Image Pre-Training( Zhai et al., ICCV 2023 )

核心改进:用 Sigmoid loss 替代 CLIP 的 Softmax (InfoNCE) loss 。

Softmax (CLIP) vs Sigmoid (SigLIP): - CLIP :需要全局归一化,依赖大 batch size ( 32K+),分布式训练需要 all-gather 操作 - SigLIP :逐对独立计算 sigmoid ,无需全局归一化,对 batch size 不敏感,可更高效地分布式训练

\[\text{CLIP Loss} = -\log \frac{\exp(\text{sim}(x_i, y_i)/\tau)}{\sum_j \exp(\text{sim}(x_i, y_j)/\tau)}\]
\[\text{SigLIP Loss} = -\frac{1}{N}\sum_{i,j} \log \sigma((-1)^{[i \neq j]} z_{ij}/\tau)\]

SigLIP 在近年的多模态研究中被广泛采用为视觉编码器候选之一,但并非所有 VLM 都使用 SigLIP。


12.3 Swin Transformer 深入

12.3.1 设计动机

ViT 的全局自注意力存在两个问题: 1. 计算复杂度\(O(N^2)\) 随分辨率平方增长,难以处理高分辨率图像 2. 缺乏多尺度特征:单一分辨率的 token 序列,不适合密集预测任务(检测/分割)

Swin Transformer 通过层级设计 + 窗口注意力解决这两个问题。

12.3.2 层级设计( Hierarchical Architecture )

Swin 采用类似 CNN 的 4 阶段层级结构,每阶段通过 Patch Merging 下采样:

Text Only
Stage 1: H/4 × W/4,  dim=C       (56×56, 96维)
    ↓ Patch Merging (2×2 → 1, dim×2)
Stage 2: H/8 × W/8,  dim=2C      (28×28, 192维)
    ↓ Patch Merging
Stage 3: H/16 × W/16, dim=4C     (14×14, 384维)
    ↓ Patch Merging
Stage 4: H/32 × W/32, dim=8C     (7×7, 768维)

Patch Merging:将相邻 \(2 \times 2\) 个 patch 拼接后通过线性层降维,类似 CNN 的 stride-2 卷积。

12.3.3 窗口注意力( Window Attention )

将 feature map 划分为 \(M \times M\)(默认 7×7 )的不重叠窗口,在每个窗口内部做自注意力:

  • 计算复杂度:\(O(M^2 \cdot N)\),其中 \(N\) 为总 token 数 → 线性复杂度
  • 对比全局注意力 \(O(N^2)\),窗口注意力在高分辨率下优势巨大

12.3.4 移位窗口( Shifted Window )

问题:窗口注意力导致窗口之间没有信息交互。

解决方案:交替使用普通窗口和移位窗口: - 奇数层:常规窗口划分 - 偶数层:窗口向右下移动 \(M/2\) 个像素后重新划分

高效实现:通过 cyclic shift + mask 实现,避免实际创建更多窗口:

Python
class WindowAttention(nn.Module):
    """Swin Transformer的窗口注意力(含相对位置偏置)"""
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (Wh, Ww)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

        # 相对位置偏置表
        # (2*Wh-1) * (2*Ww-1) 个可能的相对位置
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2*window_size[0]-1) * (2*window_size[1]-1), num_heads)
        )
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

        # 计算每个token对的相对位置索引
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))  # (2, Wh, Ww)  # torch.stack沿新维度拼接张量
        coords_flatten = torch.flatten(coords, 1)  # (2, Wh*Ww)

        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # (2, N, N)
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # (N, N, 2)
        relative_coords[:, :, 0] += window_size[0] - 1
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # (N, N)
        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self, x, mask=None):
        B_, N, C = x.shape  # B_为窗口总数
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale

        # 添加相对位置偏置
        # 用index从偏置表查找→view(N,N,heads)→permute为(heads,N,N)→unsqueeze(0)加batch维以广播加到注意力分数上
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(N, N, -1).permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)  # unsqueeze增加一个维度

        # 移位窗口的mask
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N)
            attn = attn + mask.unsqueeze(1).unsqueeze(0)  # 被mask位置设为-inf
            attn = attn.view(-1, self.num_heads, N, N)

        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        return x

def window_partition(x, window_size):
    """将feature map划分为窗口"""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows  # (num_windows*B, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    """将窗口还原为feature map"""
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

12.3.5 Swin Transformer 模型族

模型 Embed Dim Depths Heads Params ImageNet Top-1
Swin-T 96 [2,2,6,2] [3,6,12,24] 29M 81.3%
Swin-S 96 [2,2,18,2] [3,6,12,24] 50M 83.0%
Swin-B 128 [2,2,18,2] [4,8,16,32] 88M 83.5%
Swin-L 192 [2,2,18,2] [6,12,24,48] 197M 86.4%

Swin v2( 2022 ):引入 log-CPB (对数连续位置偏置)和余弦注意力,支持更高分辨率和更大模型。


12.4 高效 ViT

12.4.1 EfficientViT (MIT)

论文EfficientViT: Lightweight Multi-Scale Attention for High-Resolution Dense Prediction( Cai et al., ICCV 2023 )

核心创新: - 多尺度线性注意力:用线性注意力(\(O(N)\))替代标准 softmax 注意力(\(O(N^2)\)) - 级联分组注意力:将不同 head 分配给不同的特征分片,减少冗余计算 - 在分割任务( Cityscapes 、 ADE20K )上以较低延迟取得有竞争力的结果

模型 Params FLOPs ImageNet Top-1 GPU 延迟
EfficientViT-B1 9.1M 0.52G 79.4% 0.3ms
EfficientViT-B2 24.3M 1.6G 82.1% 0.6ms
EfficientViT-B3 48.6M 4.0G 83.5% 1.1ms

12.4.2 MobileViT

论文MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer( Mehta & Rastegari, ICLR 2022 )

设计理念:结合 MobileNet 的轻量 CNN 和 ViT 的全局注意力。

  • 用 MobileNetV2 的 Inverted Residual Block 提取局部特征
  • 在关键阶段插入 Transformer Block 捕捉全局依赖
  • 参数量仅 5-6M ,适合移动端部署

12.4.3 TinyViT

论文TinyViT: Fast Pretraining Distillation for Small Vision Transformers( Wu et al., ECCV 2022 )

方法:通过大模型蒸馏快速预训练小 ViT ( 5M-21M 参数),在 ImageNet 上达到与大模型相当的精度。

模型 Params ImageNet Top-1 用途
TinyViT-5M 5.4M 79.1% 移动端推理
TinyViT-11M 11M 81.5% 边缘设备
TinyViT-21M 21M 83.2% 轻量服务端

12.5 ViT 在下游任务中的应用

12.5.1 目标检测 — DETR

论文End-to-End Object Detection with Transformers( Carion et al., ECCV 2020 )

革命性意义:去掉了 anchor 、 NMS 等手工设计组件,用 Transformer 实现端到端检测。

架构

Text Only
图像 → CNN Backbone → Transformer Encoder → Transformer Decoder → 预测框+类别
                                              Object Queries (可学习)

匹配策略:用匈牙利算法做预测框与 GT 的二分匹配( Set Prediction )。

演进: DETR → Deformable DETR → DINO → Co-DETR → RT-DETR (实时版)

12.5.2 图像分割 — SAM

论文Segment Anything( Kirillov et al., ICCV 2023 )

SAM (Segment Anything Model) 是视觉基础模型的里程碑: - Image Encoder: ViT-H ( MAE 预训练),提取图像特征 - Prompt Encoder:编码点击、框、文本等提示 - Mask Decoder:轻量 Transformer 解码器,生成分割掩码

SAM 2( 2024 ):扩展到视频分割,引入 Memory Mechanism 追踪时序信息。

12.5.3 图像生成 — DiT

论文Scalable Diffusion Models with Transformers( Peebles & Xie, ICCV 2023 )

DiT (Diffusion Transformer):用 ViT 替代 U-Net 作为扩散模型的去噪网络。

核心设计: - 输入:带噪声的 latent patch 序列 + 时间步 embedding + 类别 embedding - 用 AdaLN-Zero (自适应 LayerNorm )注入条件信息 - DiT 类时空 token 架构常被视为理解视频生成系统的重要参考,但不应直接等同于某个闭源系统的完整官方实现

Text Only
Noisy Latent → Patchify → Transformer Blocks (with AdaLN) → Unpatchify → Denoised Latent
                              ↑ timestep + class label

12.5.4 其他重要应用

任务 代表模型 核心用法
语义分割 SegFormer, Mask2Former ViT 作为 backbone + 分割解码器
深度估计 DPT, Depth Anything ViT 编码器 + 多尺度解码器
视频理解 VideoMAE, TimeSformer 时空 patch + 视频 Transformer
点云处理 Point-BERT, Point-MAE 3D 点云 patch 化 + ViT
医学影像 MedViT, SAM-Med 预训练 ViT + 领域微调

12.6 ViT vs CNN 对比分析

12.6.1 归纳偏置( Inductive Bias )

特性 CNN ViT
局部性 卷积核限制感受野,天然捕捉局部模式 全局注意力,需要从数据中学习局部性
平移等变性 卷积权重共享保证平移等变 无此先验,依赖数据增强和位置编码
层级特征 pooling 自然构建多尺度 原始 ViT 单尺度,需 Swin 等设计引入
训练效率 归纳偏置帮助小数据学习 小数据下表现差,大数据下优势显现

12.6.2 数据效率

Text Only
数据量          CNN (ResNet)    ViT
< 10K          ★★★★           ★★
10K - 1M       ★★★★           ★★★
1M - 10M       ★★★            ★★★★
> 100M         ★★★            ★★★★★

结论:ViT 往往是一个更“通用”但也更“依赖数据和预训练”的学习器;在大数据或强预训练条件下,它常能学到与 CNN 互补、并在部分任务上更强的表示,但小数据场景并不总是占优。

12.6.3 计算复杂度

对于输入分辨率 \(H \times W\)、 patch 数 \(N = HW/P^2\)

操作 复杂度 说明
ViT Self-Attention \(O(N^2 \cdot D)\) 随分辨率二次增长
Swin Window Attention \(O(M^2 \cdot N \cdot D)\) 线性于 N
CNN 3×3 Conv \(O(9 \cdot C^2 \cdot N)\) 线性于 N
Linear Attention \(O(N \cdot D^2)\) 线性于 N

12.6.4 实践建议

场景 推荐架构 原因
预训练数据充足 ViT/EVA 往往更容易体现 scaling 优势
目标检测/分割 Swin/ViTDet 多尺度特征+高分辨率
移动端部署 MobileViT/EfficientViT 轻量高效
中小数据集 ConvNeXt/CNN+ViT 混合 归纳偏置有帮助
多模态应用 ViT+SigLIP 与语言模型对接方便

12.7 使用预训练 ViT 模型

12.7.1 HuggingFace 推理

Python
import torch
from pathlib import Path
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image

# 加载预训练模型和处理器
model_name = 'google/vit-base-patch16-224'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AutoModelForImageClassification.from_pretrained(model_name).to(device)
processor = AutoImageProcessor.from_pretrained(model_name)

model.eval()  # eval()评估模式

# 推理;该示例是单张图片分类教学骨架
image_path = Path('image.jpg')
if not image_path.exists():
    raise FileNotFoundError(f"找不到图像文件: {image_path}")

image = Image.open(image_path).convert('RGB')
inputs = processor(images=image, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

with torch.no_grad():  # 禁用梯度计算,节省内存
    outputs = model(**inputs)
    logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()  # 将单元素张量转为Python数值
    predicted_class = model.config.id2label[predicted_class_idx]
    confidence = torch.softmax(logits, dim=-1).max().item()

print(f"预测类别: {predicted_class}, 置信度: {confidence:.4f}")

12.7.2 timm 库使用

Python
import timm
import torch
from pathlib import Path
from PIL import Image
from timm.data import resolve_data_config, create_transform

# timm 覆盖了大量 ViT 变体;具体模型名会随安装版本变化
model_name = 'vit_base_patch16_224.augreg_in21k_ft_in1k'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = timm.create_model(model_name, pretrained=True).to(device)
model.eval()

# 自动获取模型对应的数据预处理
config = resolve_data_config(model.pretrained_cfg)
transform = create_transform(**config)

image_path = Path('image.jpg')
if not image_path.exists():
    raise FileNotFoundError(f"找不到图像文件: {image_path}")

image = Image.open(image_path).convert('RGB')
input_tensor = transform(image).unsqueeze(0).to(device)

with torch.no_grad():
    output = model(input_tensor)
    probabilities = torch.softmax(output, dim=-1)
    top5_prob, top5_idx = probabilities.topk(5)

for i in range(5):
    print(f"Top-{i+1}: class={top5_idx[0][i].item()}, prob={top5_prob[0][i].item():.4f}")

# 列出当前 timm 版本下可用的 ViT 模型
vit_models = timm.list_models('vit_*', pretrained=True)
print(f"可用ViT模型数量: {len(vit_models)}")

12.7.3 ViT 微调示例

Python
import torch
import torch.nn as nn
import timm

# 加载预训练ViT并修改分类头
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)
model = model.to(device)

# 下面示例默认“只训练分类头”;若数据更多,可改为全量微调并使用更小学习率
freeze_backbone = True
if freeze_backbone:
    for name, param in model.named_parameters():
        if 'head' not in name:
            param.requires_grad = False
    optimizer = torch.optim.AdamW(model.head.parameters(), lr=1e-3, weight_decay=0.05)
else:
    optimizer = torch.optim.AdamW([
        {'params': model.head.parameters(), 'lr': 1e-3},
        {'params': [p for n, p in model.named_parameters() if 'head' not in n], 'lr': 1e-5},
    ], weight_decay=0.05)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# 训练循环;假定 dataloader 已输出 resize/normalize 后的 224x224 张量
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    if len(dataloader) == 0:
        raise ValueError("dataloader 为空,无法执行训练")
    model.train()
    total_loss, correct, total = 0, 0, 0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)  # 移至GPU/CPU

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()  # 清零梯度
        loss.backward()  # 反向传播计算梯度
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()  # 更新参数

        total_loss += loss.item()
        correct += (outputs.argmax(1) == labels).sum().item()
        total += labels.size(0)

    return total_loss / len(dataloader), correct / total

12.8 练习题

基础题

  1. 简答题
  2. ViT 的 Patch Embedding 相当于什么操作?为什么用 Conv2d 实现?

    Patch Embedding 相当于用 stride=patch_size 的 Conv2d 对图像做不重叠卷积,将每个 patch 投影为一个向量。用 Conv2d 实现而非先 reshape 再线性层,是因为 Conv2d 能直接完成“切块+投影”两步,硬件利用率更高、实现更简洁。

  3. 解释 Pre-Norm 和 Post-Norm 的区别, ViT 使用哪种?

    Post-Norm在残差连接之后做 LayerNorm (原始 Transformer ),训练时梯度不稳定,需要学习率 warmup 。Pre-Norm在残差连接之前做 LayerNorm ,梯度可以直接通过残差路径流动,训练更稳定,适合深层网络。 ViT 使用 Pre-Norm 。

  4. 为什么 ViT 需要位置编码?可学习 1D 和 2D 位置编码效果有区别吗?

    Transformer 的 Self-Attention 是置换不变的( permutation-invariant ),无法感知 token 的空间位置,因此必须添加位置编码。实验表明可学习 1D 和 2D 位置编码效果差别很小(约±0.1%),因为 1D 编码在训练中可以自动学到类似 2D 的结构信息。

  5. 计算题

  6. 计算 ViT-B/16 在 224×224 输入时的 patch 数量、序列长度和 Self-Attention 的 FLOPs 。

    Patch 数量 = \((224/16)^2 = 14 \times 14 = 196\);加上 CLS token 后序列长度 \(N = 196 + 1 = 197\)。 ViT-B 的隐藏维度 \(D = 768\)。单层 Self-Attention 的 FLOPs ≈ \(2 \times 4N^2D = 2 \times 4 \times 197^2 \times 768 \approx 238M\)(包含 Q/K/V 投影和注意力计算)。 ViT-B 有 12 层,总注意力 FLOPs ≈ \(12 \times 238M \approx 2.86G\)

  7. 如果将输入分辨率改为 384×384 ,序列长度和计算量如何变化?

    Patch 数量 = \((384/16)^2 = 24 \times 24 = 576\),序列长度 \(N = 577\)。相比 224 分辨率的 197 ,序列长度增加约 2.93 倍。由于 Self-Attention 计算量与 \(N^2\) 成正比,注意力 FLOPs 增加约 \((577/197)^2 \approx 8.58\) 倍。总模型 FLOPs 增加约 3 倍左右(因为 MLP 部分仅线性增长)。

进阶题

  1. 编程题
  2. 从零实现一个完整的 ViT (包含 Patch Embedding 、 Position Embedding 、 MHA 、 MLP 、 CLS Token )。
  3. 使用 timm 加载预训练 Swin-T ,在 CIFAR-10 上微调。

  4. 分析题

  5. 比较 MAE 和 BEiT 的预训练策略,各有什么优劣?

    MAE随机掩码图像 75%的 patch ,编码器只处理可见 patch ,解码器重建原始像素。优势:预训练效率高( 3-4×加速)、无需额外组件、概念简洁。劣势:重建低层像素可能不够语义化。BEiT先用 dVAE 将图像编码为离散 token ,然后预测被掩码位置的 token ID 。优势:预测目标更语义化。劣势:需要预训练 dVAE tokenizer ,流程更复杂。

  6. 为什么 DiT 在部分扩散模型中越来越常见?

    Scaling 能力强: Transformer 更容易按层数和宽度扩展;②条件注入灵活: AdaLN-Zero 等机制便于加入 timestep 和 class/text 条件;③生态统一:与 ViT/LLM 组件更容易复用;④预训练兼容:可借用部分 ViT 预训练经验。也要注意, U-Net 在图像扩散中依然常见,是否替换取决于分辨率、算力预算与任务形态。


12.9 关键复盘

高频复盘题

Q1: ViT 相比 CNN 有什么优势和劣势?在什么场景下选择 ViT ?

参考答案优势: - 全局感受野:第一层就能看到所有 patch 之间的关系, CNN 需要很深才能获得大感受野 - 可扩展性强:参数量和数据量增加时性能持续提升( Scaling Law ) - 架构统一:与 NLP 的 Transformer 共享架构,便于多模态融合 - 预训练迁移好: MAE/CLIP 等大规模预训练后迁移效果出色

劣势: - 数据饥渴:小数据集上不如 CNN (缺乏 locality 、 translation equivariance 等归纳偏置) - 计算复杂度高: Self-Attention 的 O(N²)使高分辨率输入开销大 - 对分辨率敏感:位置编码与输入尺寸绑定,分辨率变化需要插值

选择建议:预训练数据充足时可优先考虑 ViT ;小数据用 CNN 或混合架构;移动端用 EfficientViT/MobileViT ;密集预测用 Swin/ViTDet 。


Q2: 请解释 Swin Transformer 的移位窗口机制及其作用

参考答案: Swin Transformer 在相邻层交替使用两种窗口配置: - 常规窗口( W-MSA ):将 feature map 均匀划分为 \(M \times M\) 窗口 - 移位窗口( SW-MSA ):将窗口向右下偏移 \(M/2\),跨越常规窗口边界

作用: 1. 允许相邻窗口之间交换信息,避免信息孤岛 2. 保持线性计算复杂度(只在窗口内做注意力) 3. 高效实现:用 cyclic shift + attention mask ,无需创建额外窗口


Q3: MAE 为什么要用 75%这么高的掩码率?

参考答案: 1. 图像的空间冗余高:相邻像素高度相关,低掩码率下模型可以轻松"插值"恢复,学不到高层语义 2. 高掩码率迫使模型理解语义:看到极少信息时必须"推理"内容是什么 3. 计算效率:编码器只处理 25%的 visible tokens ,预训练速度提升 3-4 倍 4. 与 NLP 对比: BERT 用 15%的掩码率,因为语言信息密度高,而图像信息密度低


Q4: DETR 如何实现端到端目标检测?相比传统检测器有什么优势?

参考答案DETR 流程: 1. CNN backbone 提取特征 → Transformer Encoder 全局建模 2. \(N\) 个可学习的 Object Queries 经 Transformer Decoder 交叉注意力 3. 每个 query 直接输出(class, box)预测 4. 用匈牙利算法做预测与 GT 的二分匹配,计算集合级别损失

优势:去掉了 anchor 设计、 NMS 后处理、 FPN 等手工组件,简化流程 劣势:训练收敛慢( 500 epochs )、小目标检测弱 → Deformable DETR 解决了这些问题


Q5: DiT 为什么在部分大规模扩散模型中越来越常见?

参考答案: 1. Scaling 能力: DiT 基于 Transformer,通常更容易沿层数/宽度扩展 2. 条件注入灵活: AdaLN-Zero 等设计便于接入 timestep 和 class/text 条件 3. 与统一骨干兼容:在多模态或视频生成体系里更容易与 ViT/LLM 组件协同 4. 预训练迁移:可借鉴部分 ViT 预训练经验 5. 补充判断:U-Net 仍广泛用于图像扩散, DiT 更像是在“大模型化”路线里增长很快,而非单向替代


Q6: SigLIP 相比 CLIP 有什么改进?为什么在部分 VLM 中更常见?

参考答案改进:用 Sigmoid loss 替代 InfoNCE (Softmax) loss - InfoNCE 需要全 batch 的负样本做归一化 → 需要超大 batch size 和 all-gather - Sigmoid loss 对每个正/负样本对独立计算 → 不依赖 batch size ,分布式训练更高效

在部分 VLM 中更常见的原因: - 训练更稳定、效率更高 - 相同参数量下 zero-shot 性能更好 - 在一批多模态系统中被采用,但并非唯一选择


Q7: 如何将 ViT 的位置编码从 224×224 分辨率迁移到 384×384 ?

参考答案: 使用双线性插值调整位置编码维度: 1. 将 \(14 \times 14\) 的 2D 位置编码 reshape 2. 用 F.interpolate 双线性插值到 \(24 \times 24\) 3. 再展平回 1D 序列 4. CLS token 的位置编码保持不变

Python
pos_embed_2d = pos_embed[:, 1:].reshape(1, 14, 14, D).permute(0, 3, 1, 2)
pos_embed_2d = F.interpolate(pos_embed_2d, size=(24, 24), mode='bicubic', align_corners=False)  # F.xxx PyTorch函数式API
new_pos_embed = torch.cat([pos_embed[:, :1], pos_embed_2d.flatten(2).transpose(1, 2)], dim=1)

Q8: 比较 ViT 、 Swin Transformer 和 ConvNeXt 三个架构的设计哲学

参考答案

特性 ViT Swin Transformer ConvNeXt
设计哲学 纯 Transformer ,最小视觉先验 Transformer + CNN 层级设计 纯 CNN + Transformer 训练技巧
注意力范围 全局 局部窗口 局部( 7×7 深度可分离卷积)
多尺度特征 无(单尺度) 有( 4 阶段层级) 有( 4 阶段层级)
计算复杂度 O(N²) O(N) O(N)
更常见使用场景 基础模型预训练 密集预测任务 需要 CNN 生态兼容
代表应用 CLIP/MAE/DiT 检测/分割 backbone 替代 ResNet

12.10 关键论文列表

必读论文

年份 论文 核心贡献
2020 ViT: An Image is Worth 16x16 Words 开创性地将纯 Transformer 用于视觉
2021 DeiT: Training data-efficient image transformers 知识蒸馏+强数据增强训练 ViT
2021 Swin Transformer: Hierarchical Vision Transformer 移位窗口+层级结构,适配密集预测
2021 BEiT: BERT Pre-Training of Image Transformers Masked Image Modeling 预训练
2022 MAE: Masked Autoencoders Are Scalable Vision Learners 75%掩码+像素重建的高效预训练
2022 ConvNeXt: A ConvNet for the 2020s 用 Transformer 技巧现代化 CNN

扩展论文

年份 论文 核心贡献
2020 DETR: End-to-End Object Detection with Transformers Transformer 端到端检测
2023 EVA-02: A Visual Representation for Neon Genesis RoPE + SwiGLU 的强力 ViT
2023 SigLIP: Sigmoid Loss for Language Image Pre-Training 更高效的视觉-语言对比学习
2023 SAM: Segment Anything ViT 驱动的视觉分割基础模型
2023 DiT: Scalable Diffusion Models with Transformers ViT 替代 U-Net 做扩散模型
2023 DINOv2: Learning Robust Visual Features 自监督 ViT 视觉基础模型
2024 SAM 2: Segment Anything in Images and Videos ViT+Memory 的视频分割

12.11 本章小结

核心知识点

  1. ViT 原始架构: Patch Embedding → CLS Token + Position Embedding → Transformer Encoder → Classification Head
  2. ViT 变体: DeiT (蒸馏)、 BEiT ( MIM )、 MAE (掩码自编码)、 EVA (大规模)、 SigLIP (高效对比学习)
  3. Swin Transformer:层级设计 + 窗口注意力 + 移位窗口 + 相对位置偏置
  4. 高效 ViT: EfficientViT 、 MobileViT 、 TinyViT——面向端侧部署
  5. 下游应用: DETR (检测)、 SAM (分割)、 DiT (生成)
  6. ViT vs CNN:归纳偏置、数据效率、计算复杂度的全面对比

下一步

下一章13-多模态学习.md - 学习 CLIP 、 BLIP 、 LLaVA 等视觉-语言多模态模型


恭喜完成第 12 章! 🎉 视觉 Transformer 已成为现代 CV 中极具影响力的一条主线——从分类、检测、分割到生成,你会反复看到 ViT 及其变体。

⚠️ 核验说明(2026-04-03):已逐段复核 ViT / SigLIP / DiT 相关表述,并按官方文档将 Hugging Face 推理示例改为更稳定的 AutoImageProcessor / AutoModelForImageClassification 接口;timm 模型名仍可能随版本变化,请以本机安装结果为准。


最后更新日期: 2026-04-03