实践:手写 Transformer¶
⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。
目标:从零开始实现一个完整的 Transformer 模型,不依赖任何预定义的 Transformer 模块。
重要原则:不要复制粘贴! 每一行代码都自己写,这是建立真正理解的唯一途径。
项目结构¶
所有代码文件(
transformer.py,train.py,data.py,config.py,requirements.txt)的完整内容已包含在本文档的 附录:完整代码参考 部分。请先自己动手实现,遇到困难时再参考。
任务说明¶
任务 1 :实现基础组件(预计 2-3 小时)¶
你需要在 transformer.py 中实现以下组件:
1.1 多头注意力( MultiHeadAttention )¶
class MultiHeadAttention(nn.Module):
"""
实现要点:
1. 初始化Q, K, V的投影矩阵
2. 实现分头逻辑(reshape + transpose)
3. 实现缩放点积注意力
4. 实现拼接和输出投影
输入输出shape:
- 输入: [batch_size, seq_len, d_model]
- 输出: [batch_size, seq_len, d_model]
"""
pass
验证方法:
# 测试代码
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512) # batch=2, seq=10
out = mha(x, x, x)
assert out.shape == (2, 10, 512), f"形状错误: {out.shape}" # assert断言:条件False时抛出AssertionError
print("✓ MultiHeadAttention 测试通过")
1.2 前馈网络( FeedForward )¶
class FeedForward(nn.Module):
"""
实现要点:
1. 两个线性层
2. 中间使用ReLU激活
3. d_model -> 4*d_model -> d_model
输入输出shape:
- 输入: [batch_size, seq_len, d_model]
- 输出: [batch_size, seq_len, d_model]
"""
pass
1.3 编码器层( EncoderLayer )¶
class EncoderLayer(nn.Module):
"""
实现要点:
1. 多头自注意力
2. 残差连接 + LayerNorm
3. 前馈网络
4. 残差连接 + LayerNorm
注意:LayerNorm在注意力/FFN之前或之后都可以
原始论文是之后,但之前更稳定(Pre-LN)
"""
pass
1.4 解码器层( DecoderLayer )¶
class DecoderLayer(nn.Module):
"""
实现要点:
1. 掩码多头自注意力
2. 残差连接 + LayerNorm
3. 交叉注意力(Encoder-Decoder Attention)
4. 残差连接 + LayerNorm
5. 前馈网络
6. 残差连接 + LayerNorm
"""
pass
任务 2 :组装完整模型(预计 1-2 小时)¶
2.1 位置编码¶
class PositionalEncoding(nn.Module):
"""
实现正弦位置编码:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
或者使用可学习的位置编码(更简单)
"""
pass
2.2 完整 Transformer¶
class Transformer(nn.Module):
"""
组装所有组件:
1. Embedding层
2. 位置编码
3. N个Encoder层
4. N个Decoder层
5. 输出投影到词表
方法:
- forward(src, tgt): 训练前向传播
- encode(src): 编码源序列
- decode(tgt, memory): 解码生成
- generate(src, max_len): 自回归生成
"""
pass
任务 3 :训练一个简单的语言模型(预计 2-3 小时)¶
3.1 准备数据¶
我们使用字符级语言模型任务:输入一段文本,预测下一个字符。
# 使用莎士比亚文本或简单儿歌
text = """
To be, or not to be, that is the question:
Whether 'tis nobler in the mind to suffer
The slings and arrows of outrageous fortune,
Or to take Arms against a Sea of troubles,
"""
3.2 实现训练循环¶
def train_epoch(model, dataloader, optimizer, criterion):
"""
实现一个epoch的训练:
1. 遍历dataloader
2. 前向传播
3. 计算损失
4. 反向传播
5. 更新参数
6. 记录loss
"""
pass
3.3 实现生成函数¶
def generate_text(model, start_text, max_length=100):
"""
自回归生成文本:
1. 编码start_text
2. 循环生成每个字符
3. 将生成的字符加入输入,继续生成
4. 直到生成max_length或结束符
"""
pass
实现提示¶
关于张量形状¶
时刻检查张量形状!这是调试 Transformer 的关键。
# 常用打印调试
print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"scores shape: {scores.shape}")
关于掩码¶
Padding Mask:处理变长序列
# 假设pad_idx = 0
# src: [batch, seq_len]
src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2) # [batch, 1, 1, seq_len] # unsqueeze增加一个维度
Causal Mask:防止看到未来信息
def create_causal_mask(size):
"""创建下三角掩码"""
mask = torch.tril(torch.ones(size, size))
return mask # [size, size]
关于 LayerNorm¶
# PyTorch内置LayerNorm
self.norm = nn.LayerNorm(d_model)
# 使用
out = self.norm(x + sublayer(x)) # 残差连接 + LayerNorm
调试检查清单¶
当你遇到问题时,按顺序检查:
- 形状检查:所有张量形状是否符合预期?
- 设备检查:所有张量是否在同一设备( CPU/GPU )?
- 掩码检查:掩码是否正确应用?(打印掩码看看)
- 梯度检查:是否有梯度消失/爆炸?(打印梯度范数)
- 数值检查:中间结果是否有 nan/inf ?
进阶挑战¶
完成基础任务后,尝试这些挑战:
挑战 1 :实现 Copy Task¶
训练模型复制输入序列: - 输入:[1, 2, 3, 4, 5] - 输出:[1, 2, 3, 4, 5]
这是验证 Transformer 是否正确实现的最简单任务。
挑战 2 :实现 Sorted Copy Task¶
训练模型对输入排序后复制: - 输入:[3, 1, 4, 2, 5] - 输出:[1, 2, 3, 4, 5]
这需要模型学习排序算法。
挑战 3 :添加学习率调度¶
实现 Warmup + Cosine Annealing 学习率调度:
挑战 4 :实现 Beam Search 生成¶
实现 Beam Search 解码,而不是 Greedy 解码:
预期结果¶
最小可运行版本¶
你应该能够: 1. 成功训练模型( loss 下降) 2. 生成看起来像样(但可能无意义)的文本 3. 在 Copy Task 上达到接近 100%准确率
良好版本¶
- 生成有一定语法结构的文本
- 训练稳定,没有 nan/inf
- 代码结构清晰,有适当注释
优秀版本¶
- 完成所有进阶挑战
- 可视化注意力权重
- 分析不同层的注意力模式
常见问题¶
Q1: 模型不收敛, loss 不下降¶
可能原因: - 学习率太大或太小 - 没有使用适当的学习率 warmup - 梯度消失(检查梯度范数) - 掩码应用错误
解决方法: - 先用很小的学习率( 1e-4 )测试 - 打印每层的梯度范数 - 简化模型(减少层数、头数)先验证正确性
Q2: 生成结果全是同一个字符¶
可能原因: - 模型陷入局部最优 - 温度参数( temperature )设置不当 - 生成时使用了 argmax 而不是 sampling
解决方法: - 使用 temperature sampling - 添加 top-k 或 top-p 采样
Q3: 显存不足¶
解决方法: - 减小 batch_size - 减小模型维度( d_model ) - 减小序列长度( seq_len ) - 使用梯度累积
参考资源¶
必读: - The Annotated Transformer - Harvard 的 Transformer 详细注释实现
参考(遇到困难时再看): - PyTorch 官方 Transformer 实现:torch.nn.Transformer - Hugging Face Transformers 库
警告:不要直接复制这些代码!先自己尝试,遇到困难时再参考思路。
提交要求¶
完成项目后,你应该提交:
- 完整的代码:
transformer.py,train.py,data.py - 训练日志: loss 曲线截图
- 生成示例:模型生成的文本样例
- 学习笔记:记录你遇到的困难和解决方法
开始吧¶
现在,关闭所有 AI 助手,打开你的 IDE ,从空白文件开始。
记住:犯错是学习的一部分。每一个你亲手解决的 bug ,都会让你变得更强。
祝你好运!🚀
附录:完整代码参考¶
⚠️ 警告:请先自己尝试实现,遇到困难时再参考以下代码。
依赖包( requirements.txt )¶
# 基础依赖
torch>=2.0.0
numpy>=1.24.0
# 可视化
matplotlib>=3.7.0
seaborn>=0.12.0
# 进度条
tqdm>=4.65.0
# 可选:用于下载数据
requests>=2.31.0
配置文件( config.py )¶
"""
Transformer模型配置
你可以根据需要调整这些参数
"""
# 导入torch(用于device检测)
import torch
class Config:
"""模型配置类"""
# 模型架构参数
vocab_size = 256 # 字符级词表大小(ASCII)
d_model = 256 # 模型维度(embedding维度)
num_heads = 8 # 注意力头数
num_encoder_layers = 4 # Encoder层数
num_decoder_layers = 4 # Decoder层数
d_ff = 512 # 前馈网络中间层维度(通常是2*d_model或4*d_model)
dropout = 0.1 # Dropout概率
max_seq_len = 128 # 最大序列长度
# 训练参数
batch_size = 32
learning_rate = 1e-4
num_epochs = 50
warmup_steps = 4000 # 学习率warmup步数
# 生成参数
max_gen_len = 100 # 最大生成长度
temperature = 1.0 # 采样温度
top_k = 50 # Top-k采样
# 设备
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# 小模型配置(用于快速测试)
class SmallConfig(Config):
"""小型配置,用于快速验证代码正确性"""
d_model = 64
num_heads = 4
num_encoder_layers = 2
num_decoder_layers = 2
d_ff = 128
batch_size = 8
num_epochs = 10
# 中模型配置
class MediumConfig(Config):
"""中型配置,平衡效果和速度"""
d_model = 512
num_heads = 8
num_encoder_layers = 6
num_decoder_layers = 6
d_ff = 2048
batch_size = 16
数据加载( data.py )¶
"""
数据加载和处理
提供字符级语言模型的数据加载
"""
import torch
from torch.utils.data import Dataset, DataLoader
class CharDataset(Dataset):
"""
字符级数据集
将文本转换为字符序列,用于语言模型训练
"""
def __init__(self, text, seq_len):
"""
Args:
text: 输入文本字符串
seq_len: 序列长度
"""
self.text = text
self.seq_len = seq_len
# 构建字符到索引的映射
self.chars = sorted(list(set(text)))
self.vocab_size = len(self.chars)
self.char_to_idx = {ch: i for i, ch in enumerate(self.chars)}
self.idx_to_char = {i: ch for i, ch in enumerate(self.chars)}
# 将文本转换为索引序列
self.data = [self.char_to_idx[ch] for ch in text]
def __len__(self):
return len(self.data) - self.seq_len
def __getitem__(self, idx):
"""
获取一个样本
对于语言模型,输入和目标错开一位:
- 输入: [x0, x1, x2, ..., xn-1]
- 目标: [x1, x2, x3, ..., xn]
"""
# 获取序列
x = self.data[idx:idx + self.seq_len]
y = self.data[idx + 1:idx + self.seq_len + 1]
return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)
def decode(self, indices):
"""将索引序列转换为字符串"""
return ''.join([self.idx_to_char[idx] for idx in indices])
def encode(self, text):
"""将字符串转换为索引序列"""
return [self.char_to_idx[ch] for ch in text]
def get_shakespeare_data():
"""
获取莎士比亚文本数据
如果本地没有,会下载一个小片段
"""
text = """
To be, or not to be, that is the question:
Whether 'tis nobler in the mind to suffer
The slings and arrows of outrageous fortune,
Or to take Arms against a Sea of troubles,
And by opposing end them: to die, to sleep;
No more; and by a sleep, to say we end
The heart-ache, and the thousand natural shocks
That Flesh is heir to? 'Tis a consummation
Devoutly to be wished. To die, to sleep,
perchance to Dream; aye, there's the rub,
For in that sleep of death, what dreams may come,
When we have shuffled off this mortal coil,
Must give us pause. There's the respect
That makes Calamity of so long life:
For who would bear the Whips and Scorns of time,
The Oppressor's wrong, the proud man's Contumely,
The pangs of despised Love, the Law's delay,
The insolence of Office, and the spurns
That patient merit of the unworthy takes,
When he himself might his Quietus make
With a bare Bodkin? Who would Fardels bear,
To grunt and sweat under a weary life,
But that the dread of something after death,
The undiscovered country, from whose bourn
No traveller returns, puzzles the will,
And makes us rather bear those ills we have,
Than fly to others that we know not of.
Thus conscience does make cowards of us all,
And thus the native hue of Resolution
Is sicklied o'er, with the pale cast of Thought,
And enterprises of great pitch and moment,
With this regard their Currents turn awry,
And lose the name of Action. Soft you now,
The fair Ophelia? Nymph, in thy Orisons
Be all my sins remember'd.
"""
return text
def get_simple_data():
"""
获取简单的儿歌数据
更容易学习,适合验证模型是否正确
"""
text = """
Twinkle, twinkle, little star,
How I wonder what you are!
Up above the world so high,
Like a diamond in the sky.
Twinkle, twinkle, little star,
How I wonder what you are!
When the blazing sun is gone,
When he nothing shines upon,
Then you show your little light,
Twinkle, twinkle, all the night.
Twinkle, twinkle, little star,
How I wonder what you are!
Then the traveler in the dark
Thanks you for your tiny spark,
He could not see which way to go,
If you did not twinkle so.
Twinkle, twinkle, little star,
How I wonder what you are!
In the dark blue sky you keep,
And often through my curtains peep,
For you never shut your eye,
Till the sun is in the sky.
Twinkle, twinkle, little star,
How I wonder what you are!
As your bright and tiny spark,
Lights the traveler in the dark,
Though I know not what you are,
Twinkle, twinkle, little star.
Twinkle, twinkle, little star,
How I wonder what you are!
"""
return text
def create_dataloaders(text, seq_len, batch_size, train_split=0.9):
"""
创建训练和验证DataLoader
Args:
text: 输入文本
seq_len: 序列长度
batch_size: 批次大小
train_split: 训练集比例
Returns:
train_loader, val_loader, dataset
"""
# 创建数据集
dataset = CharDataset(text, seq_len)
# 划分训练集和验证集
train_size = int(len(dataset) * train_split)
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
dataset, [train_size, val_size]
)
# 创建DataLoader
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
drop_last=True
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
drop_last=True
)
return train_loader, val_loader, dataset
# 简单的Copy Task数据集(用于验证模型)
class CopyDataset(Dataset):
"""
Copy Task数据集
输入: [start_token, x1, x2, ..., xn, end_token]
输出: [x1, x2, ..., xn, end_token, pad_token]
模型需要学会复制输入序列
"""
def __init__(self, num_samples=1000, seq_len=10, vocab_size=10):
self.num_samples = num_samples
self.seq_len = seq_len
self.vocab_size = vocab_size
# 特殊token
self.start_token = vocab_size
self.end_token = vocab_size + 1
self.pad_token = vocab_size + 2
self.total_vocab = vocab_size + 3
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
# 随机生成序列
seq = torch.randint(0, self.vocab_size, (self.seq_len,))
# 构建输入: [start, x1, x2, ..., xn]
src = torch.cat([
torch.tensor([self.start_token]),
seq
])
# 构建目标: [x1, x2, ..., xn, end]
tgt = torch.cat([
seq,
torch.tensor([self.end_token])
])
return src, tgt
Transformer 模型实现( transformer.py )¶
"""
Transformer模型实现
任务:从零实现Transformer的所有组件
重要:不要复制粘贴!每一行代码都自己写。
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
"""
多头注意力机制
练习: 实现以下功能
1. 初始化Q, K, V的线性投影层
2. 实现分头逻辑(reshape + transpose)
3. 实现缩放点积注意力
4. 实现拼接和输出投影
输入: [batch_size, seq_len, d_model]
输出: [batch_size, seq_len, d_model]
"""
def __init__(self, d_model, num_heads):
super().__init__() # super()调用父类方法
assert d_model % num_heads == 0, "d_model必须能被num_heads整除"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# 练习: 初始化线性投影层
# self.W_Q = ...
# self.W_K = ...
# self.W_V = ...
# self.W_O = ...
raise NotImplementedError("请实现MultiHeadAttention的__init__方法")
def scaled_dot_product_attention(self, Q, K, V, mask=None):
"""
计算缩放点积注意力
Args:
Q: [batch, num_heads, seq_len, d_k]
K: [batch, num_heads, seq_len, d_k]
V: [batch, num_heads, seq_len, d_v]
mask: [batch, 1, seq_len, seq_len] 或 [batch, 1, 1, seq_len]
Returns:
output: [batch, num_heads, seq_len, d_v]
attention_weights: [batch, num_heads, seq_len, seq_len]
"""
# 练习: 实现缩放点积注意力
# 1. 计算Q @ K^T
# 2. 除以sqrt(d_k)
# 3. 应用mask(如果有)
# 4. softmax
# 5. 乘以V
raise NotImplementedError("请实现scaled_dot_product_attention方法")
def forward(self, query, key, value, mask=None):
"""
前向传播
Args:
query: [batch, seq_len, d_model]
key: [batch, seq_len, d_model]
value: [batch, seq_len, d_model]
mask: 可选的掩码
"""
batch_size = query.size(0)
# 练习: 实现多头注意力的前向传播
# 1. 线性投影
# 2. 分头(reshape + transpose)
# 3. 计算注意力
# 4. 拼接(transpose + reshape)
# 5. 输出投影
raise NotImplementedError("请实现forward方法")
class FeedForward(nn.Module):
"""
前馈神经网络
练习: 实现以下结构
Linear(d_model, d_ff) -> ReLU -> Linear(d_ff, d_model)
输入: [batch_size, seq_len, d_model]
输出: [batch_size, seq_len, d_model]
"""
def __init__(self, d_model, d_ff, dropout=0.1):
super().__init__()
# 练习: 初始化两层线性变换
raise NotImplementedError("请实现FeedForward的__init__方法")
def forward(self, x):
# 练习: 实现前向传播
raise NotImplementedError("请实现FeedForward的forward方法")
class PositionalEncoding(nn.Module):
"""
正弦位置编码
练习: 实现正弦位置编码
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
"""
def __init__(self, d_model, max_seq_len=5000, dropout=0.1):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
# 练习: 计算位置编码矩阵
# pe shape: [max_seq_len, d_model]
raise NotImplementedError("请实现PositionalEncoding的__init__方法")
def forward(self, x):
"""
Args:
x: [batch_size, seq_len, d_model]
"""
# 练习: 将位置编码加到输入上
raise NotImplementedError("请实现PositionalEncoding的forward方法")
class EncoderLayer(nn.Module):
"""
单个Encoder层
练习: 实现以下结构
1. 多头自注意力
2. 残差连接 + LayerNorm
3. 前馈网络
4. 残差连接 + LayerNorm
"""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
# 练习: 初始化组件
raise NotImplementedError("请实现EncoderLayer的__init__方法")
def forward(self, x, src_mask=None):
"""
Args:
x: [batch, seq_len, d_model]
src_mask: [batch, 1, 1, seq_len] 或 None
"""
# 练习: 实现Encoder层的前向传播
raise NotImplementedError("请实现EncoderLayer的forward方法")
class DecoderLayer(nn.Module):
"""
单个Decoder层
练习: 实现以下结构
1. 掩码多头自注意力
2. 残差连接 + LayerNorm
3. 交叉注意力(Encoder-Decoder Attention)
4. 残差连接 + LayerNorm
5. 前馈网络
6. 残差连接 + LayerNorm
"""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
# 练习: 初始化组件
raise NotImplementedError("请实现DecoderLayer的__init__方法")
def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
"""
Args:
x: [batch, tgt_seq_len, d_model]
enc_output: [batch, src_seq_len, d_model]
src_mask: [batch, 1, 1, src_seq_len]
tgt_mask: [batch, 1, tgt_seq_len, tgt_seq_len]
"""
# 练习: 实现Decoder层的前向传播
raise NotImplementedError("请实现DecoderLayer的forward方法")
class Transformer(nn.Module):
"""
完整的Transformer模型
练习: 组装所有组件
"""
def __init__(self, vocab_size, d_model=512, num_heads=8,
num_encoder_layers=6, num_decoder_layers=6,
d_ff=2048, max_seq_len=5000, dropout=0.1):
super().__init__()
self.d_model = d_model
# 练习: 初始化组件
# 1. Embedding层
# 2. 位置编码
# 3. Encoder层列表
# 4. Decoder层列表
# 5. 输出投影层
# 6. Dropout
raise NotImplementedError("请实现Transformer的__init__方法")
# 初始化参数
self._init_parameters()
def _init_parameters(self):
"""参数初始化"""
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def make_src_mask(self, src):
"""
创建源序列的padding mask
Args:
src: [batch, src_seq_len]
Returns:
mask: [batch, 1, 1, src_seq_len]
"""
# 练习: 实现src mask
raise NotImplementedError("请实现make_src_mask方法")
def make_tgt_mask(self, tgt):
"""
创建目标序列的mask(padding + causal)
Args:
tgt: [batch, tgt_seq_len]
Returns:
mask: [batch, 1, tgt_seq_len, tgt_seq_len]
"""
# 练习: 实现tgt mask
raise NotImplementedError("请实现make_tgt_mask方法")
def forward(self, src, tgt):
"""
训练前向传播
Args:
src: [batch, src_seq_len]
tgt: [batch, tgt_seq_len]
Returns:
output: [batch, tgt_seq_len, vocab_size]
"""
# 练习: 实现完整的前向传播
raise NotImplementedError("请实现forward方法")
def encode(self, src, src_mask=None):
"""编码源序列"""
# 练习: 实现编码
raise NotImplementedError("请实现encode方法")
def decode(self, tgt, enc_output, src_mask=None, tgt_mask=None):
"""解码生成"""
# 练习: 实现解码
raise NotImplementedError("请实现decode方法")
def generate(self, src, max_len=100, start_symbol=1, end_symbol=2):
"""
自回归生成
Args:
src: [batch, src_seq_len]
max_len: 最大生成长度
start_symbol: 开始符号的索引
end_symbol: 结束符号的索引
Returns:
outputs: [batch, generated_seq_len]
"""
# 练习: 实现自回归生成
raise NotImplementedError("请实现generate方法")
# 简单的GPT风格模型(只有Decoder)
class GPT(nn.Module):
"""
GPT风格的Decoder-only模型
用于语言建模任务,比完整的Encoder-Decoder更简单
"""
def __init__(self, vocab_size, d_model=512, num_heads=8,
num_layers=6, d_ff=2048, max_seq_len=5000, dropout=0.1):
super().__init__()
# 练习: 实现GPT模型(可选)
pass
训练脚本( train.py )¶
"""
训练脚本
任务:实现Transformer的训练流程
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
# 导入你实现的模块
from transformer import Transformer
from data import CharDataset, get_simple_data, create_dataloaders, CopyDataset
from config import Config, SmallConfig
def train_epoch(model, dataloader, optimizer, criterion, device):
"""
训练一个epoch
练习: 实现训练循环
"""
model.train()
total_loss = 0
# 练习: 实现训练循环
# 1. 遍历dataloader
# 2. 将数据移到device
# 3. 清零梯度
# 4. 前向传播
# 5. 计算损失
# 6. 反向传播
# 7. 更新参数
# 8. 累加损失
raise NotImplementedError("请实现train_epoch函数")
return total_loss / len(dataloader)
def evaluate(model, dataloader, criterion, device):
"""评估模型"""
model.eval()
total_loss = 0
correct = 0
total = 0
with torch.no_grad(): # 禁用梯度计算,节省内存(推理时使用)
# 练习: 实现评估循环
pass
raise NotImplementedError("请实现evaluate函数")
return total_loss / len(dataloader), correct / total
def generate_text(model, dataset, start_text, max_length=100, device='cpu'):
"""
生成文本
练习: 实现文本生成
"""
model.eval()
# 练习: 实现文本生成
# 1. 编码start_text
# 2. 使用模型生成
# 3. 解码生成的序列
raise NotImplementedError("请实现generate_text函数")
def plot_losses(train_losses, val_losses, save_path='loss_curve.png'):
"""绘制损失曲线"""
plt.figure(figsize=(10, 6))
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Training and Validation Loss')
plt.savefig(save_path)
plt.close()
def main():
"""主训练函数"""
# 配置
config = SmallConfig()
device = torch.device(config.device)
print(f"使用设备: {device}")
# 准备数据
print("准备数据...")
text = get_simple_data()
train_loader, val_loader, dataset = create_dataloaders(
text,
seq_len=config.max_seq_len,
batch_size=config.batch_size
)
print(f"词表大小: {dataset.vocab_size}")
print(f"训练样本数: {len(train_loader.dataset)}")
print(f"验证样本数: {len(val_loader.dataset)}")
# 创建模型
print("创建模型...")
model = Transformer(
vocab_size=dataset.vocab_size,
d_model=config.d_model,
num_heads=config.num_heads,
num_encoder_layers=config.num_encoder_layers,
num_decoder_layers=config.num_decoder_layers,
d_ff=config.d_ff,
max_seq_len=config.max_seq_len,
dropout=config.dropout
).to(device) # .to(device)将数据移至GPU/CPU
# 打印模型参数量
total_params = sum(p.numel() for p in model.parameters())
print(f"模型参数量: {total_params:,}")
# 优化器和损失函数
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
criterion = nn.CrossEntropyLoss()
# 训练循环
print("\n开始训练...")
train_losses = []
val_losses = []
best_val_loss = float('inf')
for epoch in range(config.num_epochs):
# 练习: 实现完整的训练流程
# 1. 训练一个epoch
# 2. 评估
# 3. 记录损失
# 4. 保存最佳模型
# 5. 定期生成文本样例
pass # 删除这行,实现你的代码
# 保存最终模型
torch.save(model.state_dict(), 'transformer_final.pth')
print("\n训练完成!模型已保存")
# 绘制损失曲线
plot_losses(train_losses, val_losses)
print("损失曲线已保存到 loss_curve.png")
if __name__ == "__main__":
main()
最后更新日期: 2026-02-12 适用版本: LLM 学习教程 v2026