跳转至

10-Transformer 架构

学习时间: 约 8-10 小时 难度级别: ⭐⭐⭐⭐ 中高级 前置知识: 循环神经网络、注意力机制基础、线性代数 学习目标: 深入理解 Transformer 的数学原理,掌握自注意力、多头注意力、位置编码等核心组件,能从零实现 Transformer


🎯 学习目标

  1. 理解从 RNN 到 Transformer 的演进动机
  2. 掌握自注意力机制( Self-Attention )的完整数学推导
  3. 理解多头注意力( Multi-Head Attention )的设计思想
  4. 掌握各种位置编码方案(正弦、可学习、 RoPE )
  5. 深入理解 Transformer 编码器和解码器的完整结构
  6. 理解 Pre-Norm 与 Post-Norm 的区别
  7. 了解 Vision Transformer (ViT) 如何将 Transformer 应用于视觉
  8. 理解 BERT 和 GPT 的架构选择差异
  9. 能用 PyTorch 从零实现 Multi-Head Attention
  10. 了解注意力复杂度优化方向

目录


1. 从 RNN 到 Transformer

1.1 RNN 的根本限制

尽管 LSTM/GRU 在很大程度上缓解了梯度消失问题, RNN 家族仍有两个难以克服的局限:

  1. 无法并行:隐藏状态 \(\mathbf{h}_t\) 依赖 \(\mathbf{h}_{t-1}\),必须逐步计算,无法利用现代 GPU 的并行能力
  2. 长距离衰减:即使是 LSTM ,在序列长度超过几百步后,信息传递仍会逐步衰减

1.2 注意力的启示

Seq2Seq + Attention 证明了一件事:直接建立远距离连接比通过递推传递信息更高效。

那么一个自然的问题是:能否完全抛弃递推结构,只用注意力来建模序列?

1.3 Transformer 的诞生

2017 年, Vaswani 等人提出了 "Attention Is All You Need",回答是 Yes

Transformer 的核心创新: - 用自注意力( Self-Attention )取代循环结构,让序列中每个位置都能直接关注所有其他位置 - 用位置编码补偿抛弃递推后丢失的位置信息 - 完全可并行化,训练速度远超 RNN

这篇论文改变了整个 AI 领域的发展方向——BERT 、 GPT 、 ViT 、 Stable Diffusion 等几乎所有现代 AI 模型都基于 Transformer 。


2. 自注意力机制

自注意力机制

图注:自注意力机制——每个位置都能直接关注序列中所有其他位置

2.1 核心思想

给定一个序列 \((x_1, x_2, \ldots, x_n)\),自注意力让每个元素通过计算与其他所有元素的"相关性"来更新自己的表示。

直觉:在理解"The animal didn't cross the street because it was too tired"这句话时,自注意力机制可以让"it"直接关注到"animal",自动建立两者的关联。

2.2 Q/K/V 计算

对输入序列的每个位置 \(\mathbf{x}_i \in \mathbb{R}^{d_{\text{model}}}\),通过三个线性变换生成查询( Query )、键( Key )、值( Value ):

\[\mathbf{Q} = \mathbf{X} \mathbf{W}^Q, \quad \mathbf{K} = \mathbf{X} \mathbf{W}^K, \quad \mathbf{V} = \mathbf{X} \mathbf{W}^V\]

Q/K/V 计算

图注: Q/K/V 计算过程——输入通过三个不同的权重矩阵投影

其中 \(\mathbf{X} \in \mathbb{R}^{n \times d_{\text{model}}}\) 是输入矩阵,\(\mathbf{W}^Q, \mathbf{W}^K \in \mathbb{R}^{d_{\text{model}} \times d_k}\)\(\mathbf{W}^V \in \mathbb{R}^{d_{\text{model}} \times d_v}\)

直觉理解: - Query (查询):当前位置"想要找什么" - Key (键):每个位置"提供的索引标签" - Value (值):每个位置"实际包含的内容"

注意力就是用 Query 去匹配所有的 Key ,根据匹配度加权求和所有的 Value 。

2.3 缩放点积注意力( Scaled Dot-Product Attention )

缩放点积注意力

图注:缩放点积注意力计算流程——Q·K^T 计算注意力分数, softmax 归一化后加权求和 V

\[\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}}\right) \mathbf{V}\]

分步拆解:

第一步:计算注意力分数矩阵

\[\mathbf{S} = \mathbf{Q}\mathbf{K}^T \in \mathbb{R}^{n \times n}\]

\(S_{ij}\) 表示位置 \(i\) 对位置 \(j\) 的关注程度(未归一化)。

第二步:缩放

\[\mathbf{S} = \frac{\mathbf{S}}{\sqrt{d_k}}\]

为什么要除以 \(\sqrt{d_k}\)\(d_k\) 较大时,\(\mathbf{q}^T \mathbf{k}\) 的方差约为 \(d_k\)(假设 Q 和 K 的各分量独立同分布,均值 0 方差 1 ),导致值很大, softmax 输出接近 one-hot (梯度极小)。除以 \(\sqrt{d_k}\) 将方差归一到 1 ,使 softmax 输出更平滑,梯度更健康。

第三步: Softmax 归一化

\[\mathbf{A} = \text{softmax}(\mathbf{S}) \in \mathbb{R}^{n \times n}\]

每行归一化为概率分布,\(A_{ij}\) 是位置 \(i\) 分配给位置 \(j\) 的注意力权重。

第四步:加权求和

\[\text{Output} = \mathbf{A} \mathbf{V} \in \mathbb{R}^{n \times d_v}\]

每个位置的输出是所有位置 Value 的加权平均,权重就是注意力分数。

2.4 自注意力 vs 交叉注意力

  • 自注意力( Self-Attention ): Q 、 K 、 V 来自同一个序列,序列内部互看
  • 交叉注意力( Cross-Attention ): Q 来自一个序列(解码器), K 和 V 来自另一个序列(编码器),实现序列间信息交换

3. 多头注意力

多头注意力机制

图注:多头注意力机制——多个注意力头并行计算,结果拼接

3.1 动机

单头注意力只能关注一种模式的关联。多头注意力让模型能同时从不同的表示子空间中捕获不同类型的关系

例如处理"I love the cat because it is cute"时: - 第 1 个头可能关注语法关系("it" → "cat") - 第 2 个头可能关注语义关系("love" → "cute") - 第 3 个头可能关注局部上下文

3.2 数学公式

\[\text{MultiHead}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) \mathbf{W}^O\]
\[\text{head}_i = \text{Attention}(\mathbf{Q}\mathbf{W}_i^Q, \mathbf{K}\mathbf{W}_i^K, \mathbf{V}\mathbf{W}_i^V)\]

其中: - \(h\) 是头的数量(原论文 \(h=8\)) - \(d_k = d_v = d_{\text{model}} / h\)(每个头的维度) - \(\mathbf{W}_i^Q, \mathbf{W}_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}\) - \(\mathbf{W}_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}\) - \(\mathbf{W}^O \in \mathbb{R}^{h \cdot d_v \times d_{\text{model}}}\) — 输出投影矩阵

3.3 计算效率

虽然有 \(h\) 个头,但每个头的维度是 \(d_{\text{model}} / h\),所以总计算量与单头全维度注意力相当:

\[\text{单头} : d_{\text{model}}^2 \quad \text{vs} \quad \text{多头} : h \times (d_{\text{model}} / h)^2 \times 3 + d_{\text{model}}^2 \approx d_{\text{model}}^2\]

实际实现中,多头注意力通常通过一次大矩阵乘法+reshape 来并行计算所有头,非常高效。


4. 位置编码

4.1 为什么需要位置编码

自注意力是集合运算( set operation ),对输入的排列( permutation )不变:打乱输入顺序,输出只是对应打乱,不改变值。

但语言是有序的——"dog bites man" 和 "man bites dog" 含义完全不同。因此需要显式注入位置信息。

4.2 正弦/余弦位置编码( Sinusoidal Positional Encoding )

正弦位置编码

图注:正弦位置编码——不同维度使用不同频率的正弦/余弦波

原始 Transformer 论文使用的固定位置编码:

\[PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i / d_{\text{model}}}}\right)\]
\[PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i / d_{\text{model}}}}\right)\]

其中 \(pos\) 是位置索引( 0, 1, 2, ...),\(i\) 是维度索引。

设计思想: - 不同维度使用不同频率的正弦/余弦波(从高频到低频) - 每个位置有唯一的编码 - 相对位置关系可以通过线性变换表示:\(PE_{pos+k}\) 可以表示为 \(PE_{pos}\) 的线性函数 - 可外推到训练中未见过的序列长度

Python
import math

class SinusoidalPositionalEncoding(nn.Module):  # 继承nn.Module定义神经网络层
    def __init__(self, d_model, max_len=5000, dropout=0.1):  # __init__构造方法,创建对象时自动调用
        super().__init__()  # super()调用父类方法
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # 链式调用,连续执行多个方法  # unsqueeze增加一个维度
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                            (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)

        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

4.3 可学习位置编码( Learned Positional Encoding )

直接将位置编码作为可训练的 embedding :

Python
class LearnedPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512, dropout=0.1):
        super().__init__()
        self.pos_embedding = nn.Embedding(max_len, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        positions = torch.arange(x.size(1), device=x.device).unsqueeze(0)
        x = x + self.pos_embedding(positions)
        return self.dropout(x)

BERT 和 GPT 使用可学习位置编码,效果与正弦编码相当。缺点是不能外推到比训练时更长的序列。

4.4 旋转位置编码( RoPE, Rotary Position Embedding )

RoPE 旋转位置编码

图注: RoPE 旋转位置编码——通过旋转 Q 和 K 向量来编码相对位置信息

RoPE ( Su et al., 2021 )是目前大模型最流行的位置编码方案( LLaMA 、 Qwen 、 Mistral 等均采用)。

核心思想:不是将位置信息加到输入上,而是通过旋转Q 和 K 向量来编码位置信息。

对于 query 向量 \(\mathbf{q}\) 在位置 \(m\) 和 key 向量 \(\mathbf{k}\) 在位置 \(n\), RoPE 确保它们的内积只依赖于相对位置 \(m-n\)

\[\langle R(\mathbf{q}, m), R(\mathbf{k}, n) \rangle = g(\mathbf{q}, \mathbf{k}, m-n)\]

具体实现是将向量每两个维度分为一组,在每组上施加二维旋转:

\[R(\mathbf{x}, m) = \begin{pmatrix} x_1 \cos m\theta_1 - x_2 \sin m\theta_1 \\ x_1 \sin m\theta_1 + x_2 \cos m\theta_1 \\ \vdots \\ x_{d-1} \cos m\theta_{d/2} - x_d \sin m\theta_{d/2} \\ x_{d-1} \sin m\theta_{d/2} + x_d \cos m\theta_{d/2} \end{pmatrix}\]

其中 \(\theta_i = 10000^{-2i/d}\)

RoPE 的优势: - 天然编码相对位置信息 - 对长度具有更好的外推能力(配合 NTK-aware scaling 等技术) - 无额外参数 - 与注意力计算无缝结合

4.5 位置编码对比

位置编码对比

图注:不同位置编码方案的比较——正弦编码、可学习编码和 RoPE 各有优劣


5. Transformer 编码器

编码器-解码器架构

图注: Transformer 编码器和解码器架构——交叉注意力连接两者

5.1 编码器层结构

每个编码器层( Encoder Layer )包含两个子层:

Text Only
Encoder Layer:
    输入 x
    → Multi-Head Self-Attention → Add & Norm → z
    → Feed-Forward Network → Add & Norm → 输出

5.2 Layer Normalization

\[\text{LayerNorm}(\mathbf{x}) = \gamma \odot \frac{\mathbf{x} - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta\]

其中 \(\mu\)\(\sigma^2\) 是在特征维度上计算的均值和方差(与 Batch Norm 不同——Batch Norm 在 batch 维度上计算)。

为什么用 LayerNorm 而不是 BatchNorm ? - 序列长度可变, batch 中不同样本的时间步对不齐 - LayerNorm 对每个样本独立归一化,不依赖 batch 统计量 - 在推理时行为一致(无 train/eval 区别)

5.3 残差连接

每个子层的输出都加上输入(跳跃连接),与 ResNet 思想相同:

\[\text{Output} = \text{LayerNorm}(\mathbf{x} + \text{SubLayer}(\mathbf{x}))\]

残差连接确保梯度能直接回传,使训练深层 Transformer 成为可能。

5.4 前馈网络( FFN )

\[\text{FFN}(\mathbf{x}) = \text{GELU}(\mathbf{x} \mathbf{W}_1 + \mathbf{b}_1) \mathbf{W}_2 + \mathbf{b}_2\]

其中 \(\mathbf{W}_1 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}\)\(\mathbf{W}_2 \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}}\),通常 \(d_{\text{ff}} = 4 \times d_{\text{model}}\)

原论文使用 ReLU ,现代模型更多使用 GELU 或 SwiGLU 。

FFN 的作用:对每个位置独立地进行非线性变换,增加模型的表达能力。如果说注意力层负责"位置间交互", FFN 就负责"位置内变换"。

5.5 完整编码器

编码器由 \(N\) 个相同的编码器层堆叠而成(原论文 \(N=6\)):

Python
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model=512, nhead=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_mask=None, src_key_padding_mask=None):
        # Post-Norm: Norm在残差之后
        attn_out, _ = self.self_attn(x, x, x, attn_mask=src_mask,
                                      key_padding_mask=src_key_padding_mask)
        x = self.norm1(x + self.dropout(attn_out))
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)
        return x

6. Transformer 解码器

6.1 解码器层结构

每个解码器层包含三个子层:

Text Only
Decoder Layer:
    输入 y
    → Masked Multi-Head Self-Attention → Add & Norm
    → Multi-Head Cross-Attention (Q from decoder, K/V from encoder) → Add & Norm
    → Feed-Forward Network → Add & Norm → 输出

6.2 掩码自注意力( Masked Self-Attention )

解码器因果掩码

图注:解码器中的因果掩码——上三角区域被屏蔽,防止信息泄露

在训练时,解码器是并行处理整个目标序列的。但在推理时,生成是自回归的(一个一个 token 生成)。为了让训练和推理行为一致,需要让位置 \(i\) 只能看到位置 \(\le i\) 的信息——通过因果掩码( Causal Mask )实现。

\[\text{Mask}_{ij} = \begin{cases} 0 & \text{if } j \le i \\ -\infty & \text{if } j > i \end{cases}\]

在 softmax 之前加上掩码:

\[\text{Attention} = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}} + \text{Mask}\right) \mathbf{V}\]

\(-\infty\) 经过 softmax 后变为 0 ,实现了信息屏蔽。

Python
def generate_causal_mask(seq_len):
    """生成因果掩码:1=可注意, 0=屏蔽(与MHA的masked_fill(mask==0, -inf)配合使用)"""
    return torch.tril(torch.ones(seq_len, seq_len))

# 示例
mask = generate_causal_mask(5)
# tensor([[1., 0., 0., 0., 0.],
#         [1., 1., 0., 0., 0.],
#         [1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 0.],
#         [1., 1., 1., 1., 1.]])

6.3 交叉注意力( Cross-Attention )

交叉注意力机制

图注:交叉注意力机制——Query 来自解码器, Key 和 Value 来自编码器

解码器的第二个注意力子层执行交叉注意力: - Query来自解码器上一子层的输出 - Key 和 Value来自编码器的输出

这让解码器的每一步都能关注输入序列的所有位置,实现信息从编码器到解码器的传递。


7. 完整 Transformer 架构

Transformer 架构

Transformer 架构

图注: Transformer 完整架构——编码器-解码器结构

7.1 整体结构

Text Only
         输入序列                    目标序列(右移一位)
            ↓                              ↓
      输入嵌入 + 位置编码          输出嵌入 + 位置编码
            ↓                              ↓
     ┌──────────────┐              ┌──────────────┐
     │  Encoder ×N  │              │  Decoder ×N  │
     │              │───K,V──────→ │              │
     └──────────────┘              └──────────────┘
                                    Linear + Softmax
                                    输出概率分布

7.2 原始 Transformer 超参数

超参数
\(d_{\text{model}}\) 512
\(d_{\text{ff}}\) 2048
\(h\)(注意力头数) 8
\(d_k = d_v = d_{\text{model}} / h\) 64
编码器层数 \(N\) 6
解码器层数 \(N\) 6
Dropout 0.1
总参数量 ~65M

7.3 训练技巧

Warm-up 学习率调度(原论文使用):

\[\text{lr} = d_{\text{model}}^{-0.5} \cdot \min(\text{step}^{-0.5}, \text{step} \cdot \text{warmup\_steps}^{-1.5})\]

先线性增加学习率,再按步数的逆平方根衰减。现代实践中常用 Cosine Annealing 或 WSD 调度器替代。

Label Smoothing:将目标分布中正确类别的概率从 1 平滑到 \(1 - \epsilon\)(如 0.9 ),将 \(\epsilon\) 均分给其他类别,防止模型过度自信。


8. Pre-Norm vs Post-Norm

8.1 Post-Norm (原始 Transformer )

\[\text{Output} = \text{LayerNorm}(\mathbf{x} + \text{SubLayer}(\mathbf{x}))\]

LayerNorm 在残差连接之后。需要精心的学习率 warmup 才能稳定训练。

8.2 Pre-Norm (现代主流)

\[\text{Output} = \mathbf{x} + \text{SubLayer}(\text{LayerNorm}(\mathbf{x}))\]

LayerNorm 在子层之前。

Pre-Norm 的优势: - 训练更稳定,不需要 warmup - 残差路径上没有归一化,梯度传播更顺畅 - 更容易训练深层模型

Pre-Norm 的劣势: - 最终表示可能不够"归一化",通常在最后加一个额外的 LayerNorm - 有研究表明同等深度下 Post-Norm 的表达能力略强

现代大模型几乎都使用 Pre-Norm( GPT 系列、 LLaMA 等)。

Python
class PreNormEncoderLayer(nn.Module):
    """Pre-Norm Transformer编码器层"""
    def __init__(self, d_model=512, nhead=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, src_mask=None):
        # Pre-Norm: 先归一化,再进入子层
        attn_out, _ = self.self_attn(self.norm1(x), self.norm1(x), self.norm1(x), attn_mask=src_mask)
        x = x + attn_out
        x = x + self.ffn(self.norm2(x))
        return x

9. Vision Transformer

ViT 架构

图注: Vision Transformer 架构——将图像分割为 patch 后送入 Transformer 编码器

9.1 ViT 核心思想

Vision Transformer ( Dosovitskiy et al., 2020 )将 Transformer 直接应用于图像分类,证明了 Transformer 不仅限于 NLP 。

关键步骤:将图像分割为固定大小的 patch ,将每个 patch 线性映射为一个向量( token ),然后送入标准 Transformer 编码器。

9.2 图像 → 序列

Patch Embedding

图注: Patch Embedding——将图像分割为固定大小的 patch 并线性映射为向量

对于 \(H \times W \times C\) 的图像,使用 \(P \times P\) 大小的 patch :

  1. 将图像分割为 \(N = \frac{H \times W}{P^2}\) 个 patch
  2. 每个 patch 展平为 \(P^2 \times C\) 维向量
  3. 通过线性投影映射到 \(d_{\text{model}}\)
  4. 在序列前面加一个可学习的 [CLS] token
  5. 加上位置编码
\[\mathbf{z}_0 = [\mathbf{x}_{\text{cls}}; \mathbf{x}_1\mathbf{E}; \mathbf{x}_2\mathbf{E}; \ldots; \mathbf{x}_N\mathbf{E}] + \mathbf{E}_{\text{pos}}\]

其中 \(\mathbf{E} \in \mathbb{R}^{(P^2 \cdot C) \times d_{\text{model}}}\) 是 patch 嵌入矩阵。

Python
class PatchEmbedding(nn.Module):
    """将图像转化为patch序列"""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # 用卷积实现patch分割+线性投影
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: (batch, 3, 224, 224)
        x = self.proj(x)          # (batch, embed_dim, 14, 14)
        x = x.flatten(2)          # (batch, embed_dim, 196)
        x = x.transpose(1, 2)    # (batch, 196, embed_dim)
        return x

9.3 ViT 的关键发现

  • 数据量小时, ViT 不如 CNN (缺少 CNN 的归纳偏置:局部性和平移不变性)
  • 数据量大时( JFT-300M 等), ViT 显著超越 CNN
  • DeiT 等后续工作通过知识蒸馏和数据增强,使 ViT 在 ImageNet 上也能超越 CNN

10. BERT 与 GPT 的架构选择

10.1 BERT——编码器架构

BERT ( Bidirectional Encoder Representations from Transformers ) 只使用 Transformer 编码器

  • 双向注意力:每个 token 可以看到整个序列(前后文)
  • 预训练任务
  • MLM ( Masked Language Modeling ):随机遮住 15%的 token ,模型预测被遮住的词
  • NSP ( Next Sentence Prediction ):判断两个句子是否是连续的
  • 适用场景:理解型任务(分类、 NER 、 QA 、句子相似度)
  • 限制:不能直接用于文本生成(因为是双向的)

10.2 GPT——解码器架构

GPT ( Generative Pre-trained Transformer ) 只使用 Transformer 解码器(去掉交叉注意力):

  • 因果注意力( Causal/Masked Attention ):每个 token 只能看到左边的 token
  • 预训练任务:自回归语言建模 — 预测下一个 token
  • 适用场景:生成型任务(文本生成、对话、代码生成)
  • 优势:统一的"预测下一个 token"范式,可通过 scaling law 持续提升

10.3 对比总结

特性 BERT (编码器) GPT (解码器)
注意力方向 双向(全局可见) 单向(因果掩码)
预训练目标 MLM + NSP 自回归 LM
擅长任务 理解(分类、抽取) 生成(续写、对话)
代表模型 BERT, RoBERTa, DeBERTa GPT-⅔/4, LLaMA, Qwen
发展趋势 逐渐被 decoder-only 取代 当前大模型主流

现代趋势: Decoder-only 架构( GPT 式)通过提示( Prompting )和上下文学习( In-Context Learning )也能很好地完成理解型任务,因此成为大模型的主流选择。


11. PyTorch 实现:从零实现 Multi-Head Attention

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    """从零实现多头注意力机制"""

    def __init__(self, d_model=512, num_heads=8, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"  # assert断言,条件为False时抛出异常

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # 每个头的维度

        # Q/K/V投影矩阵(可以合并为一个大矩阵提升效率)
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

        # 输出投影
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)

    def split_heads(self, x, batch_size):
        """
        将(batch, seq_len, d_model)拆分为(batch, num_heads, seq_len, d_k)
        """
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """
        缩放点积注意力
        Q, K, V: (batch, num_heads, seq_len, d_k)
        mask: (batch, 1, 1, seq_len) 或 (1, 1, seq_len, seq_len)
        """
        # Step 1: 计算注意力分数
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        # scores: (batch, num_heads, seq_len_q, seq_len_k)

        # Step 2: 应用掩码(如有)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Step 3: Softmax归一化
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Step 4: 加权求和
        output = torch.matmul(attn_weights, V)
        # output: (batch, num_heads, seq_len_q, d_k)

        return output, attn_weights

    def forward(self, query, key, value, mask=None):
        """
        前向传播
        query, key, value: (batch, seq_len, d_model)
        mask: 可选掩码
        Returns: (batch, seq_len, d_model), attention_weights
        """
        batch_size = query.size(0)

        # 1. 线性投影
        Q = self.W_q(query)  # (batch, seq_len, d_model)
        K = self.W_k(key)
        V = self.W_v(value)

        # 2. 拆分为多头
        Q = self.split_heads(Q, batch_size)  # (batch, num_heads, seq_len, d_k)
        K = self.split_heads(K, batch_size)
        V = self.split_heads(V, batch_size)

        # 3. 缩放点积注意力
        attn_output, attn_weights = self.scaled_dot_product_attention(Q, K, V, mask)

        # 4. 合并多头
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.d_model)

        # 5. 输出投影
        output = self.W_o(attn_output)

        return output, attn_weights

# ==================== 测试 ====================
def test_multi_head_attention():
    d_model = 512
    num_heads = 8
    seq_len = 20
    batch_size = 4

    mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)

    # 自注意力测试
    x = torch.randn(batch_size, seq_len, d_model)
    output, weights = mha(x, x, x)
    print(f"Input shape:  {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Attention weights shape: {weights.shape}")
    assert output.shape == (batch_size, seq_len, d_model)
    assert weights.shape == (batch_size, num_heads, seq_len, seq_len)

    # 验证注意力权重行和为1
    row_sums = weights.sum(dim=-1)
    assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-6)
    print("All tests passed!")

    # 因果掩码测试
    causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)
    output_causal, weights_causal = mha(x, x, x, mask=causal_mask)
    # 验证上三角的注意力权重为0
    upper_attn = weights_causal[:, :, 0, 1:]  # 第0个位置不应关注后面的位置
    assert torch.allclose(upper_attn, torch.zeros_like(upper_attn), atol=1e-6)
    print("Causal mask test passed!")

# test_multi_head_attention()

12. 复杂度分析与优化

12.1 标准注意力的复杂度

对于长度为 \(n\) 的序列,维度为 \(d\)

  • 时间复杂度\(O(n^2 \cdot d)\) — 主要来自 \(\mathbf{Q}\mathbf{K}^T\) 的计算
  • 空间复杂度\(O(n^2 + n \cdot d)\) — 需要存储 \(n \times n\) 的注意力矩阵

\(n\) 很大时(如处理长文档、高分辨率图像),\(n^2\) 成为严重瓶颈。

12.2 优化方向

稀疏注意力( Sparse Attention )

不计算完整的 \(n \times n\) 注意力矩阵,只计算其中一部分:

  • 固定模式:滑动窗口 + 全局 token ( Longformer, BigBird )
  • 学习模式:让模型学习关注哪些位置( Reformer 用 LSH 近似)

复杂度降至 \(O(n \cdot \sqrt{n})\)\(O(n \cdot w)\)\(w\) 为窗口大小)。

线性注意力( Linear Attention )

利用核技巧将注意力重写为线性形式:

\[\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \frac{\phi(\mathbf{Q})(\phi(\mathbf{K})^T \mathbf{V})}{\phi(\mathbf{Q})(\phi(\mathbf{K})^T \mathbf{1})}\]

先计算 \(\phi(\mathbf{K})^T \mathbf{V} \in \mathbb{R}^{d \times d}\),再与 \(\phi(\mathbf{Q})\) 相乘,复杂度 \(O(n \cdot d^2)\)

代表方法: Linear Transformer, Performer (Random Feature Attention)。

Flash Attention

不改变数学公式,而是利用 GPU 内存层次结构优化注意力计算的IO 效率

  • 通过分块( tiling )将计算放到 GPU SRAM 中
  • 避免将完整的 \(n \times n\) 注意力矩阵写回 HBM
  • 训练速度提升 2-4 倍,内存使用降至 \(O(n)\)

Flash Attention 已成为现代大模型训练的标配( PyTorch 2.0+内置 torch.nn.functional.scaled_dot_product_attention 自动使用 Flash Attention )。

状态空间模型( SSM )

Mamba 等模型用结构化的状态空间方程替代注意力:

\[\mathbf{h}_t = \bar{\mathbf{A}} \mathbf{h}_{t-1} + \bar{\mathbf{B}} \mathbf{x}_t, \quad \mathbf{y}_t = \mathbf{C} \mathbf{h}_t\]

训练时可并行化(通过并行扫描),推理时为 \(O(1)\) 常量时间(类似 RNN ),是后 Transformer 时代的重要方向之一。

12.3 复杂度对比

方法 时间复杂度 空间复杂度 特点
标准注意力 \(O(n^2 d)\) \(O(n^2)\) 精确,但不可扩展
稀疏注意力 \(O(n \cdot w \cdot d)\) \(O(n \cdot w)\) 窗口大小 \(w\) 限制
线性注意力 \(O(n d^2)\) \(O(nd)\) 近似,性能有损
Flash Attention \(O(n^2 d)\) \(O(n)\) 精确, IO 优化
SSM ( Mamba ) \(O(nd)\) \(O(nd)\) 线性,新范式

13. 练习与自我检查

✏️ 练习题

  1. 手动计算:给定 3 个二维向量 \(\mathbf{x}_1=[1,0], \mathbf{x}_2=[0,1], \mathbf{x}_3=[1,1]\)\(\mathbf{W}^Q=\mathbf{W}^K=\mathbf{W}^V=\mathbf{I}\)(单位矩阵),手动计算自注意力输出。

  2. 实现完整 Transformer:基于本章提供的 Multi-Head Attention ,补全 Transformer 编码器和解码器的完整实现。

  3. 位置编码可视化:实现正弦位置编码,可视化不同位置和不同维度的编码值(热力图),观察周期性模式。

  4. 因果掩码验证:实现一个 decoder-only Transformer ,验证因果掩码确实阻止了信息泄露——打乱位置 \(i\) 之后的 token 不应影响位置 \(i\) 的输出。

  5. 注意力可视化:训练一个小型 Transformer 进行机器翻译,可视化不同头的注意力矩阵,分析各头学到了什么模式。

  6. 效率对比:实现标准注意力和 Flash Attention (使用 PyTorch 2.0 的 SDPA ),在不同序列长度下对比速度和内存使用。

  7. ViT 实现:用 PyTorch 实现一个简化版 ViT ,在 CIFAR-10 上训练,对比与 ResNet 的效果。

面试要点

Q1: 解释 Transformer 中自注意力的计算过程? A: 输入经 Q/K/V 线性投影,计算 Q 和 K 的缩放点积得到注意力分数,经 softmax 归一化后加权求和 V 。公式:\(\text{softmax}(\frac{QK^T}{\sqrt{d_k}})V\)。 Multi-Head 将 d_model 拆分为多个头,分别计算后拼接投影。

Q2: 为什么要除以 \(\sqrt{d_k}\)A: Q 和 K 的点积期望为 0 、方差为\(d_k\)(假设各分量独立标准正态)。不缩放时,\(d_k\)大会导致 softmax 饱和(梯度极小)。除以\(\sqrt{d_k}\)将方差归一到 1 。

Q3: Pre-Norm 和 Post-Norm 的区别? A: Post-Norm 在残差之后归一化( Layer→Residual→Norm ),训练不稳定需要 warmup 。 Pre-Norm 先归一化再进子层( Norm→Layer→Residual ),训练更稳定,是现代大模型的标准选择。

Q4: GPT 和 BERT 在架构上有什么区别? A: BERT 用 Transformer 编码器(双向注意力+MLM 预训练),适合理解任务。 GPT 用 Transformer 解码器(因果注意力+自回归 LM 预训练),适合生成任务。 GPT 架构已成为大模型主流。

Q5: Transformer 处理长序列有什么问题,如何优化? A: 标准注意力\(O(n^2)\)复杂度,长序列计算和内存开销巨大。优化方向:稀疏注意力( Longformer )、 Flash Attention ( IO 优化)、线性注意力(核近似)、 SSM ( Mamba ,线性复杂度)。

自我检查清单

  • 能完整写出缩放点积注意力公式和 Multi-Head Attention 公式
  • 理解 Q/K/V 的直觉含义和数学作用
  • 理解为什么要缩放(\(\sqrt{d_k}\)
  • 能解释正弦位置编码的设计思想,了解 RoPE 的优势
  • 理解编码器和解码器的结构差异(掩码、交叉注意力)
  • 能从零实现 Multi-Head Attention
  • 理解 Pre-Norm vs Post-Norm 的实际影响
  • 能说清 BERT 和 GPT 的架构选择和适用场景
  • 了解至少 3 种注意力优化方法
  • 理解 ViT 如何将图像转化为 token 序列

下一章: ../02-卷积神经网络/01-卷积神经网络基础 — 或继续进阶章节