跳转至

01 - Transformer 深入理解(全面版)

⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。

学习目标:从零开始深入理解 Transformer 的每一个组件,掌握其数学原理、实现细节和设计思想。

📌 定位说明:本章侧重大模型视角下的 Transformer 深入理解( RoPE/RMSNorm/GeLU/编解码器架构对比等)。 Transformer 架构的基础教学(从零实现完整 Transformer )请参考 深度学习/04-Transformer/02-Transformer 架构


目录

  1. Transformer 架构总览
  2. 输入嵌入层深度解析
  3. 位置编码机制
  4. 自注意力机制详解
  5. 多头注意力机制
  6. 前馈神经网络
  7. 层归一化与残差连接
  8. 编解码器架构对比
  9. Transformer 的训练
  10. Transformer 的推理

Transformer 架构总览

1.1 为什么需要 Transformer

在 Transformer 出现之前,序列建模主要依赖 RNN/LSTM :

Text Only
RNN的问题:
├── 顺序计算,无法并行
├── 长距离依赖困难(梯度消失/爆炸)
└── 计算复杂度与序列长度成正比

Transformer的解决方案:
├── 完全基于注意力机制
├── 完全并行计算
├── 任意位置间距离都是O(1)
└── 成为现代NLP的基础架构

1.2 架构全景图

Text Only
┌─────────────────────────────────────────────────────────────────┐
│                     Transformer 架构全景                         │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │                    输入序列 (Input)                      │    │
│  │              [我, 喜欢, 深度, 学习]                       │    │
│  └─────────────────────────────────────────────────────────┘    │
│                              │                                   │
│                              ▼                                   │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │              输入嵌入 + 位置编码                          │    │
│  │         (Token Embedding + Positional Encoding)          │    │
│  └─────────────────────────────────────────────────────────┘    │
│                              │                                   │
│           ┌──────────────────┴──────────────────┐               │
│           ▼                                      ▼               │
│  ┌─────────────────┐                  ┌─────────────────┐       │
│  │    Encoder      │                  │    Decoder      │       │
│  │  (编码器堆叠)    │                  │  (解码器堆叠)    │       │
│  │                 │                  │                 │       │
│  │  ┌───────────┐  │                  │  ┌───────────┐  │       │
│  │  │ Block N   │  │                  │  │ Block N   │  │       │
│  │  │ ┌───────┐ │  │                  │  │ ┌───────┐ │  │       │
│  │  │ │Multi- │ │  │                  │  │ │Masked │ │  │       │
│  │  │ │Head   │ │  │                  │  │ │Multi-H│ │  │       │
│  │  │ │Attn   │ │  │                  │  │ │Attn   │ │  │       │
│  │  │ └───────┘ │  │                  │  │ └───────┘ │  │       │
│  │  │ ┌───────┐ │  │                  │  │ ┌───────┐ │  │       │
│  │  │ │ Feed  │ │  │                  │  │ │Multi-H│ │  │       │
│  │  │ │Forward│ │  │                  │  │ │Cross  │ │  │       │
│  │  │ └───────┘ │  │                  │  │ │Attn   │ │  │       │
│  │  └───────────┘  │                  │  │ └───────┘ │  │       │
│  │       ...       │                  │  │ ┌───────┐ │  │       │
│  │  ┌───────────┐  │                  │  │ │ Feed  │ │  │       │
│  │  │ Block 1   │  │                  │  │ │Forward│ │  │       │
│  │  └───────────┘  │                  │  │ └───────┘ │  │       │
│  └─────────────────┘                  │  └───────────┘  │       │
│           │                           └─────────────────┘       │
│           │                                  │                   │
│           └──────────────────┬───────────────┘                   │
│                              ▼                                   │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │                    输出层 (Output)                       │    │
│  │              Linear + Softmax → 概率分布                 │    │
│  └─────────────────────────────────────────────────────────┘    │
│                              │                                   │
│                              ▼                                   │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │                    预测输出                              │    │
│  │              [I, like, deep, learning]                   │    │
│  └─────────────────────────────────────────────────────────┘    │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

Transformer 架构图

上图展示了 Transformer 的完整架构,包括编码器( Encoder )和解码器( Decoder )的堆叠结构,以及自注意力机制和前馈神经网络等核心组件。

1.3 核心组件一览

组件 功能 关键参数 复杂度
输入嵌入 将 token 映射为向量 vocab_size × d_model -
位置编码 注入位置信息 d_model -
多头注意力 捕捉不同子空间的关系 h heads, d_k = d_model/h O(n²·d)
前馈网络 非线性变换 d_model → 4d_model → d_model O(n·d²)
层归一化 稳定训练 d_model O(n·d)
残差连接 缓解梯度消失 - -

输入嵌入层深度解析

2.1 词嵌入的数学本质

Text Only
词嵌入层就是一个查找表(Lookup Table):

E ∈ ℝ^(V × d)

其中:
- V: 词汇表大小(如32000)
- d: 嵌入维度(如512, 768, 1024)

对于输入token id = i:
embedding = E[i, :] ∈ ℝ^d

这相当于一个one-hot向量与嵌入矩阵的乘法:
embedding = one_hot(i) @ E

2.2 嵌入层的实现细节

Python
import torch
import torch.nn as nn
import math

class TokenEmbedding(nn.Module):
    """
    Token嵌入层完整实现
    """
    def __init__(self, vocab_size, d_model, padding_idx=None):
        super().__init__()  # super()调用父类方法
        self.vocab_size = vocab_size
        self.d_model = d_model

        # 嵌入矩阵
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=d_model,
            padding_idx=padding_idx  # 用于mask填充位置
        )

        # 缩放因子(Transformer原论文使用)
        self.scale = math.sqrt(d_model)

        # 初始化(重要!)
        self._init_weights()

    def _init_weights(self):
        """
        嵌入层初始化策略
        使用N(0, 1/d_model)初始化
        """
        nn.init.normal_(self.embedding.weight, mean=0, std=1/math.sqrt(self.d_model))

        # 如果有padding_idx,将该位置嵌入置零
        if self.embedding.padding_idx is not None:
            with torch.no_grad():  # 禁用梯度追踪,避免初始化操作被记入计算图
                self.embedding.weight[self.embedding.padding_idx].fill_(0)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len] token IDs
        Returns:
            embeddings: [batch_size, seq_len, d_model]
        """
        # 查找嵌入并缩放
        # 缩放原因:后续要与位置编码相加,需要平衡量级
        return self.embedding(x) * self.scale

# 使用示例
vocab_size = 32000
d_model = 512
batch_size = 2
seq_len = 10

embedding_layer = TokenEmbedding(vocab_size, d_model)
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
embeddings = embedding_layer(input_ids)

print(f"输入形状: {input_ids.shape}")
print(f"嵌入形状: {embeddings.shape}")
print(f"嵌入范围: [{embeddings.min():.3f}, {embeddings.max():.3f}]")

2.3 子词分词与嵌入

现代大模型使用子词( Subword )分词,如 BPE 、 WordPiece 、 SentencePiece :

Text Only
传统分词的问题:
- "playing" 和 "played" 被视为完全不同的词
- 未登录词(OOV)问题

子词分词的优势:
"playing" → ["play", "ing"]
"unhappiness" → ["un", "happiness"] 或 ["un", "happy", "ness"]

常见分词器:
├── GPT系列: BPE (Byte-Pair Encoding)
├── BERT: WordPiece
├── LLaMA/T5: SentencePiece (Unigram)
└── 中文: 字级别或BPE
Python
# 使用Hugging Face Tokenizer示例
from transformers import AutoTokenizer

# 加载GPT-2的分词器
tokenizer = AutoTokenizer.from_pretrained("gpt2")

text = "Transformer架构 revolutionized NLP"
tokens = tokenizer.tokenize(text)
print(f"分词结果: {tokens}")
# 输出: ['Trans', 'former', '架构', 'Ġre', 'volution', 'ized', 'ĠN', 'LP']

# 转换为ID
input_ids = tokenizer.encode(text)
print(f"Token IDs: {input_ids}")

# 解码
decoded = tokenizer.decode(input_ids)
print(f"解码结果: {decoded}")

位置编码机制

3.1 为什么需要位置编码

Text Only
自注意力的排列不变性:

输入: [A, B, C] → 自注意力 → 输出
输入: [B, A, C] → 自注意力 → 输出(只是顺序变了)

问题:模型无法区分"我打你"和"你打我"

解决方案:注入位置信息

3.2 正弦位置编码( Sinusoidal )

Transformer 原始论文使用的位置编码:

Text Only
PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

其中:
- pos: 位置(0, 1, 2, ..., max_len-1)
- i: 维度索引(0, 1, 2, ..., d_model/2-1)
- d_model: 模型维度
Python
class SinusoidalPositionalEncoding(nn.Module):
    """
    正弦位置编码实现
    """
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # 创建位置编码矩阵 [max_len, d_model]
        pe = torch.zeros(max_len, d_model)

        # 位置索引 [max_len, 1]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # unsqueeze增加一个维度

        # 维度索引的除数项
        # 10000^(2i/d_model) = exp(2i * -log(10000) / d_model)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() *
            (-math.log(10000.0) / d_model)
        )

        # 偶数维度用sin
        pe[:, 0::2] = torch.sin(position * div_term)

        # 奇数维度用cos
        pe[:, 1::2] = torch.cos(position * div_term)

        # 注册为buffer(不参与训练)
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, d_model]
        """
        seq_len = x.size(1)

        # 添加位置编码
        x = x + self.pe[:, :seq_len, :]

        return self.dropout(x)

# 可视化位置编码
import matplotlib.pyplot as plt

def visualize_positional_encoding():
    d_model = 128
    max_len = 100

    pe = SinusoidalPositionalEncoding(d_model, max_len)
    encoding = pe.pe[0].numpy()  # [max_len, d_model]

    plt.figure(figsize=(12, 6))
    plt.imshow(encoding, aspect='auto', cmap='viridis')
    plt.colorbar()
    plt.xlabel('Dimension')
    plt.ylabel('Position')
    plt.title('Sinusoidal Positional Encoding')
    plt.show()

    # 观察不同位置的编码
    plt.figure(figsize=(12, 4))
    for pos in [0, 10, 20, 50]:
        plt.plot(encoding[pos], label=f'Pos {pos}')
    plt.legend()
    plt.xlabel('Dimension')
    plt.ylabel('Value')
    plt.title('Positional Encoding at Different Positions')
    plt.show()

# 正弦位置编码的性质
"""
性质1: 唯一性
每个位置都有唯一的编码

性质2: 相对位置关系
PE(pos+k) 可以表示为 PE(pos) 的线性函数
这意味着模型可以学习相对位置

性质3: 有界性
所有值都在[-1, 1]之间,数值稳定

性质4: 外推性
可以处理训练时未见过的更长序列
"""

3.3 可学习位置编码

BERT 等模型使用可学习的位置编码:

Python
class LearnedPositionalEncoding(nn.Module):
    """
    可学习位置编码
    """
    def __init__(self, d_model, max_len=512, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # 可学习的嵌入矩阵
        self.pos_embedding = nn.Embedding(max_len, d_model)

        # 初始化
        nn.init.normal_(self.pos_embedding.weight, std=0.02)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, d_model]
        """
        seq_len = x.size(1)

        # 创建位置索引
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)

        # 添加位置嵌入
        x = x + self.pos_embedding(positions)

        return self.dropout(x)

3.4 旋转位置编码 (RoPE)

现代大模型( LLaMA 、 GPT-NeoX 等)使用 RoPE :

3.4.1 RoPE 的数学推导(从复数乘法出发)

Text Only
RoPE的核心目标:设计一种位置编码函数 f(x, pos),使得两个位置的内积
只依赖于相对位置差 (m - n),而非绝对位置 m 和 n:

    <f(q, m), f(k, n)> = g(q, k, m - n)

第一步:复数表示
  将向量的每两个相邻维度视为一个复数:
    z = x_{2i} + j·x_{2i+1}  (j 是虚数单位)

第二步:复数乘法 = 旋转
  在复数平面上,乘以 e^{jθ} 等价于旋转角度 θ:
    z · e^{jθ} = (x_{2i} + j·x_{2i+1})(cosθ + j·sinθ)
              = (x_{2i}·cosθ - x_{2i+1}·sinθ) + j·(x_{2i}·sinθ + x_{2i+1}·cosθ)

  写成矩阵形式:
    [x_{2i}' ]   [cosθ  -sinθ] [x_{2i}  ]
    [x_{2i+1}'] = [sinθ   cosθ] [x_{2i+1}]

第三步:位置相关的旋转角度
  对位置 m 的第 i 对维度,旋转角度 θ_i(m) = m · ω_i
  其中 ω_i = 1 / 10000^{2i/d}(频率随维度指数递减)

第四步:验证相对位置性质
  对位置 m 的 query 和位置 n 的 key(在第 i 对维度上):
    f(q, m) = q · e^{j·m·ω_i}
    f(k, n) = k · e^{j·n·ω_i}

  内积(复数):
    Re[f(q, m) · conj(f(k, n))]
    = Re[q · e^{j·m·ω_i} · conj(k · e^{j·n·ω_i})]
    = Re[q · conj(k) · e^{j·(m-n)·ω_i}]

  结果只依赖 (m-n),证毕。✓

第五步:推广到全维度
  对 d 维向量,每两维一组,共 d/2 组,每组用不同频率 ω_i 旋转:

    RoPE(x, pos) = [R(pos·ω_0)·x_{0:1}, R(pos·ω_1)·x_{2:3}, ..., R(pos·ω_{d/2-1})·x_{d-2:d-1}]

  其中 R(θ) 是 2×2 旋转矩阵。

3.4.2 RoPE 实现

Python
class RotaryPositionalEmbedding(nn.Module):
    """
    旋转位置编码 (RoPE)
    通过旋转矩阵注入相对位置信息
    """
    def __init__(self, dim, max_seq_len=2048, base=10000):
        super().__init__()

        # 计算旋转角度
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # 预计算位置编码
        t = torch.arange(max_seq_len)
        freqs = torch.einsum('i,j->ij', t, inv_freq)  # [max_seq_len, dim/2]

        # 复数形式: cos + i*sin
        self.register_buffer('cos_cached', freqs.cos()[None, None, :, :])  # [1, 1, seq_len, dim/2]
        self.register_buffer('sin_cached', freqs.sin()[None, None, :, :])  # [1, 1, seq_len, dim/2]

    def forward(self, x, seq_len=None):
        """
        Args:
            x: [batch, heads, seq_len, head_dim]
        """
        if seq_len is None:
            seq_len = x.shape[2]

        cos = self.cos_cached[:, :, :seq_len, :]
        sin = self.sin_cached[:, :, :seq_len, :]

        return self.apply_rotary_pos_emb(x, cos, sin)

    def apply_rotary_pos_emb(self, x, cos, sin):
        """
        应用旋转位置编码

        Args:
            x: [batch, heads, seq_len, head_dim]
            cos: [1, 1, seq_len, head_dim/2]
            sin: [1, 1, seq_len, head_dim/2]

        Returns:
            [batch, heads, seq_len, head_dim]
        """
        # 将x分成两部分(偶数维和奇数维)
        x1, x2 = x[..., ::2], x[..., 1::2]

        # 旋转: [x1, x2] @ [[cos, -sin], [sin, cos]]
        # 即: x1' = x1 * cos - x2 * sin
        #     x2' = x1 * sin + x2 * cos
        rotated = torch.stack([
            x1 * cos - x2 * sin,
            x1 * sin + x2 * cos
        ], dim=-1)

        # 将最后两维展平,恢复原始形状
        return rotated.flatten(-2)

RoPE 旋转位置编码可视化

上图展示了 RoPE (旋转位置编码)的工作原理。 RoPE 通过旋转矩阵将位置信息注入到查询和键向量中,每个维度对在不同的旋转速度下工作,类似于不同速度的时钟指针。

RoPE 2D 投影可视化

这张图展示了 RoPE 在 2D 平面上的投影,不同颜色代表不同位置的向量,它们从原点发出,展示了实部和虚部之间的关系。 RoPE 的优势在于能够自然地编码相对位置信息,并且具有良好的外推性能。


自注意力机制详解

4.1 直觉理解

自注意力的核心思想:让序列中每个位置都"看到"其他所有位置,并学习该关注哪些位置。

Text Only
以"猫坐在垫子上,它很舒服"为例:

处理"它"时,自注意力权重分布(示意):
猫  坐  在  垫子 上  ,  它  很 舒服
0.45 0.05 0.02 0.15 0.03 0.01 0.10 0.04 0.15

→ 模型学会了"它"主要指代"猫"(最高权重),次要关联"垫子"和"舒服"

4.2 Q-K-V 的数学推导

\[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \]

三个矩阵的角色: - \(Q\)( Query ):当前位置发出的"提问" - \(K\)( Key ):每个位置提供的"索引标签" - \(V\)( Value ):每个位置存储的"实际内容"

Text Only
类比图书馆:
Q = 你的搜索关键词("深度学习入门")
K = 每本书的标题/关键词索引
V = 每本书的实际内容

检索流程:
1. 计算 Q @ K^T → 搜索词与每本书标题的匹配分数
2. softmax → 将分数归一化为概率
3. 概率 @ V → 加权提取最相关书的内容
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def scaled_dot_product_attention(
    query: torch.Tensor,   # [batch, seq_q, d_k]
    key: torch.Tensor,     # [batch, seq_k, d_k]
    value: torch.Tensor,   # [batch, seq_k, d_v]
    mask: torch.Tensor = None,
    dropout: nn.Dropout = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    缩放点积注意力

    关键点:
    1. 为什么除以 sqrt(d_k)?(完整推导)

       假设 q 和 k 的每个分量独立同分布,均值为0、方差为1:
         q = [q_1, q_2, ..., q_{d_k}],  q_i ~ (0, 1)
         k = [k_1, k_2, ..., k_{d_k}],  k_i ~ (0, 1)

       点积 q·k = Σ_{i=1}^{d_k} q_i * k_i

       每一项 q_i * k_i 的方差:
         Var(q_i * k_i) = E[q_i²] * E[k_i²] - (E[q_i] * E[k_i])²
                        = 1 * 1 - 0 = 1

       因为各项独立,所以:
         Var(q·k) = Σ Var(q_i * k_i) = d_k

       当 d_k=64 时,点积标准差 ≈ 8;d_k=128 时 ≈ 11.3
       这些大值会将 softmax 推入饱和区(梯度接近 0)

       除以 √d_k 后:
         Var(q·k / √d_k) = Var(q·k) / d_k = d_k / d_k = 1

       方差归一化为 1,softmax 的输入保持在合理范围,梯度可以正常流动。

    2. mask 的两种用途:
       - padding mask: 屏蔽填充位置
       - causal mask: 屏蔽未来位置(解码器)
    """
    d_k = query.size(-1)

    # Step 1: QK^T / sqrt(d_k) → [batch, seq_q, seq_k]
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    # Step 2: 应用 mask(将需要屏蔽的位置设为 -inf)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Step 3: Softmax → 权重矩阵
    attn_weights = F.softmax(scores, dim=-1)

    # Step 4: Dropout(训练时)
    if dropout is not None:
        attn_weights = dropout(attn_weights)

    # Step 5: 加权求和 → [batch, seq_q, d_v]
    output = torch.matmul(attn_weights, value)

    return output, attn_weights

# ---------- 验证 ----------
batch, seq_len, d_k = 2, 5, 64
Q = torch.randn(batch, seq_len, d_k)
K = torch.randn(batch, seq_len, d_k)
V = torch.randn(batch, seq_len, d_k)

# 无 mask
out, weights = scaled_dot_product_attention(Q, K, V)
print(f"输出形状: {out.shape}")           # [2, 5, 64]
print(f"权重形状: {weights.shape}")        # [2, 5, 5]
print(f"权重行和: {weights.sum(-1)}")      # 每行和为 1.0

# 因果 mask(解码器用)
causal_mask = torch.tril(torch.ones(seq_len, seq_len))  # 下三角矩阵
out_masked, weights_masked = scaled_dot_product_attention(Q, K, V, mask=causal_mask)
print(f"\n因果mask权重(第0行只看自己,第4行看所有位置):")
print(weights_masked[0].detach())

4.3 注意力复杂度分析

Text Only
时间复杂度: O(n² · d)
  - 其中 n = 序列长度, d = 模型维度
  - 主要来自 QK^T 的矩阵乘法

空间复杂度: O(n²)
  - 需要存储完整的注意力权重矩阵

序列长度 →  注意力计算量(d=1024)
128         → 16.7M
1,024       → 1.07B
8,192       → 68.7B
128,000     → 16.8T  ← 这就是长上下文的挑战

解决方案:
├── FlashAttention: 利用GPU内存层级优化,O(n²)复杂度不变但实际速度快2-4倍
├── GQA (Grouped-Query Attention): 减少KV头数,降低内存
├── MQA (Multi-Query Attention): 所有头共享一组KV
├── Ring Attention: 分布式环形通信处理超长序列
└── 稀疏注意力: 只计算部分位置对的注意力

多头注意力机制

5.1 为什么需要多头

单头注意力的问题:一个注意力头只能学到一种"关注模式"。而自然语言有多种关系维度(语法、语义、共指、因果等)。

Text Only
多头的直觉:用不同的"眼睛"看同一句话

Head 1: 可能学会了语法依赖(主语→动词)
Head 2: 可能学会了共指关系(代词→先行词)
Head 3: 可能学会了修饰关系(形容词→名词)
Head 4: 可能学会了长距离依赖(问题→答案)
...

5.2 多头注意力实现

\[ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O \]
\[ \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) \]
Python
class MultiHeadAttention(nn.Module):
    """
    多头注意力机制

    参数量分析(d_model=1024, h=16, d_k=64):
    W_Q: 1024×1024 = 1M
    W_K: 1024×1024 = 1M
    W_V: 1024×1024 = 1M
    W_O: 1024×1024 = 1M
    总计: 4M 参数
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model 必须被 n_heads 整除"  # assert断言:条件False时抛出AssertionError

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # 每个头的维度

        # 四个投影矩阵
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        query: torch.Tensor,   # [batch, seq_q, d_model]
        key: torch.Tensor,     # [batch, seq_k, d_model]
        value: torch.Tensor,   # [batch, seq_k, d_model]
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        batch_size = query.size(0)

        # 1. 线性投影 → [batch, seq, d_model]
        Q = self.W_Q(query)
        K = self.W_K(key)
        V = self.W_V(value)

        # 2. 分头: [batch, seq, d_model] → [batch, n_heads, seq, d_k]
        Q = Q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)  # view重塑张量形状
        K = K.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        # 3. 缩放点积注意力
        d_k = self.d_k
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(0).unsqueeze(0)
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # 4. 加权求和 → [batch, n_heads, seq_q, d_k]
        context = torch.matmul(attn_weights, V)

        # 5. 合并多头: [batch, seq_q, d_model]
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        # 6. 输出投影
        output = self.W_O(context)

        return output

# 验证
mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(2, 10, 512)

# Self-attention: Q=K=V=x
out = mha(x, x, x)
print(f"Self-attention 输出: {out.shape}")  # [2, 10, 512]

# Cross-attention: Q来自decoder, K/V来自encoder
enc_out = torch.randn(2, 20, 512)
dec_in = torch.randn(2, 10, 512)
cross_out = mha(dec_in, enc_out, enc_out)
print(f"Cross-attention 输出: {cross_out.shape}")  # [2, 10, 512]

5.3 GQA 与 MQA :现代大模型的注意力优化

Text Only
标准 MHA (Multi-Head Attention):
  Q: [batch, n_heads, seq, d_k]    ← 每个头有独立的Q
  K: [batch, n_heads, seq, d_k]    ← 每个头有独立的K
  V: [batch, n_heads, seq, d_k]    ← 每个头有独立的V
  KV cache 大小 = 2 × n_heads × seq × d_k

MQA (Multi-Query Attention, GPT-J):
  Q: [batch, n_heads, seq, d_k]    ← 每个头有独立的Q
  K: [batch, 1, seq, d_k]          ← 所有头共享一组K
  V: [batch, 1, seq, d_k]          ← 所有头共享一组V
  KV cache 大小 = 2 × 1 × seq × d_k  → 减少 n_heads 倍

GQA (Grouped-Query Attention, LLaMA-2/3):
  Q: [batch, n_heads, seq, d_k]    ← 每个头有独立的Q
  K: [batch, n_kv_heads, seq, d_k] ← 每组共享一组K
  V: [batch, n_kv_heads, seq, d_k] ← 每组共享一组V
  KV cache 大小 = 2 × n_kv_heads × seq × d_k

示例 (LLaMA-3-70B): n_heads=64, n_kv_heads=8
  → 每8个Q头共享1组KV → KV cache减少8倍,性能接近MHA

前馈神经网络

6.1 标准 FFN

每个 Transformer 层中的全连接前馈网络( Position-wise FFN ):

\[ \text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2 \]
Python
class FeedForward(nn.Module):
    """标准FFN,4倍扩展"""
    def __init__(self, d_model: int, d_ff: int = None, dropout: float = 0.1):
        super().__init__()
        d_ff = d_ff or 4 * d_model  # 默认4倍扩展

        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # [batch, seq, d_model] → [batch, seq, d_ff] → [batch, seq, d_model]
        return self.w2(self.dropout(F.relu(self.w1(x))))

6.2 现代激活函数: GeLU 与 SwiGLU

Python
# GeLU (GPT-2/BERT 使用)
# GELU(x) = x · Φ(x),其中 Φ 是标准正态分布的CDF
# 比 ReLU 更平滑,不会在 x=0 处产生不可导点

class GeLUFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int = None, dropout: float = 0.1):
        super().__init__()
        d_ff = d_ff or 4 * d_model
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w2(self.dropout(F.gelu(self.w1(x))))

# SwiGLU (LLaMA/Qwen/Mistral 使用)
# SwiGLU(x, W, V, W2) = (Swish(xW) ⊙ xV) W2
# 引入门控机制,效果优于 GeLU,但参数量增加 50%

class SwiGLUFeedForward(nn.Module):
    """LLaMA-style SwiGLU FFN"""
    def __init__(self, d_model: int, d_ff: int = None, dropout: float = 0.1):
        super().__init__()
        # LLaMA 使用 8/3 * d_model 作为 d_ff 以保持参数量不变
        d_ff = d_ff or int(8 / 3 * d_model)
        # 取最接近的 256 的倍数(GPU对齐优化)
        d_ff = ((d_ff + 255) // 256) * 256

        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate projection
        self.w3 = nn.Linear(d_model, d_ff, bias=False)  # up projection
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # down projection
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Swish(xW1) ⊙ xW3 → 门控机制
        gate = F.silu(self.w1(x))  # SiLU = Swish(x) = x * sigmoid(x)
        up = self.w3(x)
        return self.w2(self.dropout(gate * up))

# 对比
d_model = 1024
ffn_relu = FeedForward(d_model)
ffn_gelu = GeLUFeedForward(d_model)
ffn_swiglu = SwiGLUFeedForward(d_model)

for name, module in [("ReLU FFN", ffn_relu), ("GeLU FFN", ffn_gelu), ("SwiGLU FFN", ffn_swiglu)]:
    params = sum(p.numel() for p in module.parameters())
    print(f"{name}: {params:,} 参数")
# ReLU FFN: 8,393,728 参数
# GeLU FFN: 8,393,728 参数
# SwiGLU FFN: 8,650,752 参数 (d_ff调整后参数量接近)

层归一化与残差连接

7.1 为什么需要归一化

深层网络中,每层输出的分布会不断漂移( Internal Covariate Shift ),导致训练不稳定。归一化将输出拉回标准分布。

7.2 LayerNorm vs RMSNorm

Python
class LayerNorm(nn.Module):
    """
    标准层归一化 (BERT/GPT-2 使用)
    LayerNorm(x) = γ · (x - μ) / √(σ² + ε) + β
    """
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))   # 可学习缩放
        self.beta = nn.Parameter(torch.zeros(d_model))    # 可学习偏移
        self.eps = eps

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

class RMSNorm(nn.Module):
    """
    RMS层归一化 (LLaMA/Qwen/Mistral 使用)
    RMSNorm(x) = γ · x / RMS(x)
    其中 RMS(x) = √(1/n · Σx²)

    优势:去掉了均值计算和偏移参数 β
    → 计算更快(约10-15%),效果相当
    """
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x):
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return self.gamma * x / rms

# 对比
x = torch.randn(2, 10, 512)
ln = LayerNorm(512)
rmsn = RMSNorm(512)

out_ln = ln(x)
out_rmsn = rmsn(x)
print(f"LayerNorm 输出统计: mean={out_ln.mean():.4f}, std={out_ln.std():.4f}")
print(f"RMSNorm 输出统计:   mean={out_rmsn.mean():.4f}, std={out_rmsn.std():.4f}")

7.3 Pre-Norm vs Post-Norm

Text Only
Post-Norm (原始Transformer, BERT):
  x → Attention → Add(x, ·) → LayerNorm → FFN → Add(·, ·) → LayerNorm

  优点:最终层的表示经过归一化,理论性质更好
  缺点:深层训练不稳定,需要 warm-up

Pre-Norm (GPT-2, LLaMA, 现代模型):
  x → LayerNorm → Attention → Add(x, ·) → LayerNorm → FFN → Add(·, ·)

  优点:梯度直接通过残差传播,训练更稳定
  缺点:理论上收敛性稍差,但实践中更常用

现代大模型几乎都使用 Pre-Norm + RMSNorm 的组合。
Python
class TransformerBlock(nn.Module):
    """Pre-Norm Transformer Block (LLaMA-style)"""

    def __init__(self, d_model: int, n_heads: int, d_ff: int = None, dropout: float = 0.1):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm2 = RMSNorm(d_model)
        self.ffn = SwiGLUFeedForward(d_model, d_ff, dropout)

    def forward(self, x, mask=None):
        # Pre-Norm + Residual
        x_norm = self.norm1(x)
        x = x + self.attn(x_norm, x_norm, x_norm, mask)
        x = x + self.ffn(self.norm2(x))
        return x

编解码器架构对比

8.1 三种主流架构

这是理解现代大模型最关键的知识点之一:

Text Only
┌────────────────────────────────────────────────────────────────┐
│                    三种 Transformer 架构                        │
├──────────────┬──────────────┬──────────────┬──────────────────┤
│              │ Encoder-Only │ Decoder-Only │ Encoder-Decoder  │
├──────────────┼──────────────┼──────────────┼──────────────────┤
│ 代表模型     │ BERT         │ GPT/LLaMA    │ T5/BART          │
│ 注意力类型   │ 双向全注意力  │ 因果注意力    │ 双向+因果+交叉   │
│ 训练目标     │ MLM + NSP    │ 下一token预测 │ Seq2Seq          │
│ 擅长任务     │ 分类/NER/QA  │ 生成/对话     │ 翻译/摘要        │
│ 上下文       │ 看到全部输入  │ 只看到左侧    │ 编码器全部,解码器左侧│
│ 当前趋势     │ 逐渐减少     │ 主流(GPT时代) │ 特定场景使用     │
└──────────────┴──────────────┴──────────────┴──────────────────┘

为什么 Decoder-Only 成为主流?
1. 统一的训练目标:所有任务都可以转化为"文本生成"
2. 规模效应更好:参数量越大,涌现能力越强
3. 少样本能力:通过 prompt 即可适配新任务,无需微调
4. 工程简单:只需一个模型处理所有任务

8.2 因果注意力掩码

Decoder-Only 模型的核心:确保位置 \(i\) 只能看到位置 \(0, 1, \ldots, i\)(不泄露未来信息)。

Python
def create_causal_mask(seq_len: int) -> torch.Tensor:
    """
    创建因果注意力掩码

    例如 seq_len=4:
    [[1, 0, 0, 0],   ← 位置0只看自己
     [1, 1, 0, 0],   ← 位置1看0和1
     [1, 1, 1, 0],   ← 位置2看0、1、2
     [1, 1, 1, 1]]   ← 位置3看所有
    """
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask

# 训练时:一次前向传播同时预测所有位置
# 因为有因果mask,每个位置只看到左侧,等价于自回归

Transformer 的训练

9.1 训练目标

Text Only
Decoder-Only (如 GPT) 的训练:

输入:  [BOS] The  cat  sat  on  the  mat
标签:  The   cat  sat  on   the mat  [EOS]

损失函数: Cross-Entropy Loss
L = -1/T Σ_t log P(x_t | x_{<t})

即最大化每个位置正确预测下一个token的对数概率。
整个序列只需一次前向传播 (teacher forcing)。

9.2 学习率调度

Transformer 训练中最关键的超参数管理:

Python
class WarmupCosineScheduler:
    """
    Warmup + Cosine Decay 学习率调度
    几乎所有现代大模型都使用这种方式

    阶段1 (warmup): 学习率从0线性增长到max_lr
    阶段2 (cosine): 学习率从max_lr余弦衰减到min_lr
    """
    def __init__(
        self,
        optimizer,
        warmup_steps: int,
        total_steps: int,
        max_lr: float = 3e-4,
        min_lr: float = 1e-5,
    ):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.step_count = 0

    def step(self):
        self.step_count += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def get_lr(self):
        if self.step_count < self.warmup_steps:
            # 线性 warmup
            return self.max_lr * self.step_count / self.warmup_steps
        else:
            # 余弦衰减
            progress = (self.step_count - self.warmup_steps) / (
                self.total_steps - self.warmup_steps
            )
            return self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (
                1 + math.cos(math.pi * progress)
            )

# 训练循环骨架
"""
model = DecoderOnlyTransformer(...)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)
scheduler = WarmupCosineScheduler(optimizer, warmup_steps=2000, total_steps=100000)

for batch in dataloader:
    input_ids = batch['input_ids']                      # [B, T]
    targets = input_ids[:, 1:]                          # 右移一位
    logits = model(input_ids[:, :-1])                   # [B, T-1, vocab]
    loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # 梯度裁剪
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
"""

Transformer 的推理

10.1 自回归生成

Text Only
训练是并行的(一次处理整个序列),
推理是自回归的(一次生成一个token):

Step 0: 输入 [BOS] → 预测 "The"
Step 1: 输入 [BOS, The] → 预测 "cat"
Step 2: 输入 [BOS, The, cat] → 预测 "sat"
...

问题:每一步都要重新计算之前所有位置的注意力 → 巨大浪费
解决:KV Cache

10.2 KV Cache

KV Cache 是大模型推理最重要的优化技术:

Python
"""KV Cache 原理示意"""

class CachedMultiHeadAttention(nn.Module):
    """带KV Cache的多头注意力(推理用)"""

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, kv_cache=None):
        """
        Args:
            x: [batch, 1, d_model] — 推理时只有当前新token
            kv_cache: (cached_K, cached_V) — 之前步骤的K/V
        Returns:
            output, new_kv_cache
        """
        B = x.size(0)

        # 只为当前 token 计算 Q/K/V
        Q = self.W_Q(x).view(B, 1, self.n_heads, self.d_k).transpose(1, 2)
        K_new = self.W_K(x).view(B, 1, self.n_heads, self.d_k).transpose(1, 2)
        V_new = self.W_V(x).view(B, 1, self.n_heads, self.d_k).transpose(1, 2)

        if kv_cache is not None:
            K_cached, V_cached = kv_cache
            # 拼接历史 KV → [batch, n_heads, seq_so_far+1, d_k]
            K = torch.cat([K_cached, K_new], dim=2)
            V = torch.cat([V_cached, V_new], dim=2)
        else:
            K, V = K_new, V_new

        # 只计算当前 Q 与所有 K 的注意力 → O(1×S) 而非 O(S×S)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)

        out = out.transpose(1, 2).contiguous().view(B, 1, self.d_model)
        output = self.W_O(out)

        return output, (K, V)  # 返回更新后的 cache

"""
无 KV Cache: 生成 T 个token → 总计算量 O(T² · d)
有 KV Cache: 生成 T 个token → 总计算量 O(T · d) + 缓存内存
  速度提升: ~T倍(序列越长,加速越明显)

代价: 需要额外内存存储 KV Cache
  LLaMA-2-7B, seq_len=4096:
  KV cache = 2(K+V) × 32(layers) × 32(heads) × 4096(seq) × 128(d_k) × 2(bytes/fp16)
           ≈ 2GB
"""

10.3 解码策略

Python
def generate(
    model, tokenizer, prompt: str,
    max_new_tokens: int = 100,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.9,
):
    """
    完整的文本生成函数,支持多种解码策略
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    kv_cache = None
    generated = []

    for _ in range(max_new_tokens):
        # 只输入最新 token(有 KV Cache 时)
        if kv_cache is not None:
            x = input_ids[:, -1:]
        else:
            x = input_ids

        logits, kv_cache = model(x, kv_cache=kv_cache)
        logits = logits[:, -1, :]  # 只取最后一个位置

        # Temperature scaling
        logits = logits / temperature

        # Top-K: 只保留概率最高的K个token
        if top_k > 0:
            topk_logits, topk_indices = torch.topk(logits, top_k)
            logits = torch.full_like(logits, float('-inf'))
            logits.scatter_(1, topk_indices, topk_logits)

        # Top-P (Nucleus Sampling)
        probs = F.softmax(logits, dim=-1)
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # 移除累积概率超过 top_p 的token
        remove_mask = cumulative_probs - sorted_probs > top_p
        sorted_probs[remove_mask] = 0
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)

        # 采样
        next_token = torch.multinomial(sorted_probs, 1)
        next_token = sorted_indices.gather(1, next_token)

        if next_token.item() == tokenizer.eos_token_id:
            break

        generated.append(next_token.item())
        input_ids = torch.cat([input_ids, next_token], dim=-1)

    return tokenizer.decode(generated)

延伸阅读


最后更新日期: 2025-07-11 适用版本: LLM 学习教程 v2025