第 15 章:Rust AI 推理引擎¶
学习时间:6 小时 | 难度:⭐⭐⭐⭐⭐ 专家 | 前置知识:全部前置章节,特别是第 ⅜/11 章(所有权、智能指针、异步)
本章概览¶
Rust 在 AI 推理领域正在快速崛起:速度接近 C++、内存安全、无 GC 暂停、可编译为 WASM。本章覆盖 HuggingFace 的 Candle(Rust 原生 ML 框架)、Burn(深度学习训练框架)和基于 llama.cpp 的大模型推理。
学习目标:
- 使用 Candle 加载并运行 HuggingFace 模型(BERT, Mistral, Llama)
- 理解张量(Tensor)操作的 Rust API
- 掌握模型量化(INT8/INT4)减少内存占用
- 使用 Burn 框架定义自定义神经网络
- 编译 Rust AI 推理代码为 WASM,在浏览器运行
- 构建流式 LLM 推理服务(类似 llama.cpp server)
15.1 Rust AI 生态概览¶
| 框架 | 定位 | 特点 |
|---|---|---|
| Candle | 推理框架 | HuggingFace 出品,支持 CUDA/Metal,API 类似 PyTorch |
| Burn | 训练+推理 | 多后端(Libtorch/WGPU/NdArray),可训练自定义模型 |
| llama.cpp + Rust | LLM 推理 | 基于 C++ 核心,Rust 通过 FFI 调用,量化支持最完善 |
| ort (ONNX Runtime) | 生产推理 | 使用 ONNX Runtime C++ 库,支持几乎所有模型格式 |
| tch-rs | PyTorch 绑定 | LibTorch 的 Rust 绑定,API 与 Python PyTorch 最接近 |
15.2 Candle:HuggingFace Rust 推理框架¶
TOML
# Cargo.toml
[dependencies]
candle-core = { version = "0.8", features = ["cuda"] } # CUDA 支持
candle-nn = "0.8" # 神经网络层(线性层、激活函数等)
candle-transformers = "0.8" # HuggingFace 模型(BERT、GPT2、Llama 等)
tokenizers = "0.21" # HuggingFace 分词器
hf-hub = "0.3" # 自动从 HuggingFace Hub 下载模型
anyhow = "1"
tokio = { version = "1", features = ["full"] }
张量(Tensor)基础操作¶
Rust
use candle_core::{Device, DType, Tensor};
fn tensor_basics() -> anyhow::Result<()> {
// 选择计算设备:CPU 或 CUDA GPU
let device = Device::new_cuda(0).unwrap_or(Device::Cpu);
// let device = Device::new_metal(0)?; // macOS M1/M2/M3 Apple Silicon
// 创建张量
let a = Tensor::new(&[1.0f32, 2.0, 3.0], &device)?; // 1D 向量
let b = Tensor::new(&[[1.0f32, 2.0], [3.0, 4.0]], &device)?; // 2D 矩阵
println!("a shape: {:?}", a.shape()); // [3]
println!("b shape: {:?}", b.shape()); // [2, 2]
// 基础运算
let c = (&a + 1.0)?; // 标量加法:[2, 3, 4]
let d = (&a * &a)?; // 逐元素乘法:[1, 4, 9]
let sum = a.sum_all()?; // 求和:6.0
let mean = a.mean_all()?; // 均值:2.0
// 矩阵乘法(重要:DNN 的核心操作)
let w = Tensor::new(&[[1.0f32, 0.0], [0.0, 1.0]], &device)?; // 单位矩阵
let result = b.matmul(&w)?; // b @ w = b(不变)
// 形状变换
let flat = b.flatten_all()?; // [[1,2],[3,4]] → [1,2,3,4]
let reshaped = flat.reshape((2, 2))?; // 回到 [2, 2]
let transposed = b.t()?; // 转置:[[1,3],[2,4]]
// 索引与切片
let first_row = b.get(0)?; // 第一行:[1, 2]
let sliced = b.i((.., 0..1))?; // 所有行,第 0 列
// 类型转换
let int_tensor = a.to_dtype(DType::I64)?; // float32 → int64
Ok(())
}
15.3 BERT 文本嵌入¶
BERT 是最常用的文本理解模型,广泛用于语义搜索、文本分类等任务:
Rust
use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use tokenizers::Tokenizer;
use hf_hub::{api::sync::Api, Repo, RepoType};
/// 使用 BERT 将文本编码为向量(用于语义搜索)
pub struct TextEmbedder {
model: BertModel,
tokenizer: Tokenizer,
device: Device,
}
impl TextEmbedder {
/// 从 HuggingFace Hub 自动下载模型
pub fn from_pretrained(model_id: &str) -> anyhow::Result<Self> {
let device = Device::new_cuda(0).unwrap_or(Device::Cpu);
let api = Api::new()?; // 使用 HF_TOKEN 环境变量认证
// hf_hub 自动缓存模型到 ~/.cache/huggingface/
let repo = api.repo(Repo::new(model_id.to_string(), RepoType::Model));
let config_path = repo.get("config.json")?;
let tokenizer_path = repo.get("tokenizer.json")?;
let weights_path = repo.get("model.safetensors")?;
// 加载配置
let config: BertConfig = serde_json::from_str(&std::fs::read_to_string(config_path)?)?;
// 加载分词器
let tokenizer = Tokenizer::from_file(tokenizer_path)?;
// 加载模型权重(safetensors 格式,零拷贝内存映射)
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device)?
};
let model = BertModel::load(vb, &config)?;
Ok(Self { model, tokenizer, device })
}
/// 将一批文本编码为向量
pub fn encode_batch(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
// 分词:文本 → token IDs
let encodings = self.tokenizer
.encode_batch(texts.to_vec(), true) // true = 添加特殊 token([CLS], [SEP])
.map_err(|e| anyhow::anyhow!("Tokenizer error: {}", e))?;
// 找最长序列长度(统一 padding)
let seq_len = encodings.iter().map(|e| e.len()).max().unwrap_or(0);
// 构建输入张量
let input_ids: Vec<u32> = encodings.iter()
.flat_map(|enc| {
let mut ids = enc.get_ids().to_vec();
ids.resize(seq_len, 0); // padding with 0
ids
})
.collect();
let attention_mask: Vec<u32> = encodings.iter()
.flat_map(|enc| {
let mut mask = enc.get_attention_mask().to_vec();
mask.resize(seq_len, 0);
mask
})
.collect();
let batch_size = texts.len();
// 转为 Candle 张量(形状:[batch_size, seq_len])
let input_ids_tensor = Tensor::from_vec(input_ids, (batch_size, seq_len), &self.device)?;
let attention_mask_tensor = Tensor::from_vec(attention_mask, (batch_size, seq_len), &self.device)?;
// 前向传播
let outputs = self.model.forward(
&input_ids_tensor,
&attention_mask_tensor,
None, // token_type_ids(段落 ID,BERT 特有,这里不需要)
)?;
// 取 [CLS] token 的输出作为句子向量(第一个 token)
// outputs shape: [batch_size, seq_len, hidden_size]
let cls_embeddings = outputs.i((.., 0, ..))?; // [batch_size, hidden_size]
// L2 归一化(余弦相似度计算的前提)
let norm = cls_embeddings.sqr()?.sum(1)?.sqrt()?.unsqueeze(1)?;
let normalized = (&cls_embeddings / &norm)?;
// 转换为 Rust Vec
let embeddings: Vec<Vec<f32>> = (0..batch_size)
.map(|i| {
normalized.get(i).unwrap().to_vec1::<f32>().unwrap()
})
.collect();
Ok(embeddings)
}
/// 计算两段文本的语义相似度(余弦相似度)
pub fn similarity(&self, text1: &str, text2: &str) -> anyhow::Result<f32> {
let embeddings = self.encode_batch(&[text1, text2])?;
let e1 = &embeddings[0];
let e2 = &embeddings[1];
// 余弦相似度 = e1 · e2(因为已经归一化,直接点积)
let dot_product: f32 = e1.iter().zip(e2.iter()).map(|(a, b)| a * b).sum();
Ok(dot_product)
}
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// 使用多语言 BERT(mBERT),支持中文
let embedder = TextEmbedder::from_pretrained("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")?;
// 语义搜索示例
let query = "如何学习 Rust?";
let docs = [
"Rust 是一门系统编程语言,注重安全性和性能",
"Python 是一门动态类型的脚本语言",
"《Rust 程序设计语言》是学习 Rust 的最佳书籍",
"JavaScript 是 Web 开发的核心语言",
];
let query_embedding = embedder.encode_batch(&[query])?;
let doc_embeddings = embedder.encode_batch(&docs)?;
// 找出最相关的文档
let mut similarities: Vec<(f32, &str)> = doc_embeddings.iter()
.zip(docs.iter())
.map(|(emb, doc)| {
let score: f32 = query_embedding[0].iter()
.zip(emb.iter())
.map(|(a, b)| a * b)
.sum();
(score, *doc)
})
.collect();
similarities.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
println!("查询:{}", query);
println!("相关文档:");
for (score, doc) in &similarities {
println!(" [{:.3}] {}", score, doc);
}
Ok(())
}
15.4 Llama 大模型推理与量化¶
Rust
// Cargo.toml 追加
// candle-transformers 已经包含 Llama 模型实现
use candle_core::{DType, Device, Tensor};
use candle_transformers::models::llama::{Cache, Config as LlamaConfig, Llama};
use candle_transformers::generation::LogitsProcessor;
use tokenizers::Tokenizer;
/// Llama/Mistral 大模型推理器
pub struct LlamaInference {
model: Llama,
tokenizer: Tokenizer,
cache: Cache, // KV Cache(加速生成)
device: Device,
config: LlamaConfig,
}
impl LlamaInference {
pub fn from_pretrained(model_id: &str, use_quantized: bool) -> anyhow::Result<Self> {
let device = Device::new_cuda(0).unwrap_or(Device::Cpu);
let api = hf_hub::api::sync::Api::new()?;
let repo = api.repo(hf_hub::Repo::new(model_id.to_string(), hf_hub::RepoType::Model));
let config_path = repo.get("config.json")?;
let tokenizer_path = repo.get("tokenizer.json")?;
let config: LlamaConfig = serde_json::from_str(&std::fs::read_to_string(config_path)?)?;
let tokenizer = Tokenizer::from_file(tokenizer_path)?;
// 量化模型使用 GGUF 格式(INT4/INT8),大幅降低内存
// 非量化:Mistral-7B 需要 ~14GB VRAM (FP16)
// INT4 量化:Mistral-7B 只需 ~4GB VRAM
let dtype = if use_quantized { DType::F16 } else { DType::F32 };
let weights = if use_quantized {
// 使用 Q4_K_M 量化版本
repo.get("model.gguf")?
} else {
repo.get("model.safetensors")?
};
let vb = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[weights], dtype, &device)?
};
let cache = Cache::new(true, dtype, &config, &device)?; // true = use_kv_cache
let model = Llama::load(vb, &config)?;
Ok(Self { model, tokenizer, cache, device, config })
}
/// 流式文本生成
pub fn generate_stream(
&mut self,
prompt: &str,
max_tokens: usize,
temperature: f64,
) -> anyhow::Result<String> {
// 将 prompt 编码为 token IDs
let tokens = self.tokenizer
.encode(prompt, true)
.map_err(|e| anyhow::anyhow!("Tokenize error: {}", e))?;
let mut token_ids: Vec<u32> = tokens.get_ids().to_vec();
// 采样器:控制生成的随机性
let mut logits_processor = LogitsProcessor::new(
42, // 随机种子(固定 seed 使结果可复现)
Some(temperature), // temperature:越高越随机(0.7 是常用值)
Some(0.9), // top_p:核采样,过滤低概率 token
);
let mut generated_text = String::new();
let mut pos = 0; // 当前位置(用于 KV Cache)
for _ in 0..max_tokens {
// 准备输入张量
let input = Tensor::from_vec(token_ids.clone(), (1, token_ids.len()), &self.device)?;
// 前向传播
let logits = self.model.forward(&input, pos, &mut self.cache)?;
// logits shape: [1, seq_len, vocab_size]
// 只取最后一个 token 的 logits
let logits = logits.squeeze(0)?; // [seq_len, vocab_size]
let last_logits = logits.get(logits.dims()[0] - 1)?; // [vocab_size]
// 采样下一个 token
let next_token = logits_processor.sample(&last_logits)?;
// 检查是否生成了 EOS(结束)token
if next_token == self.config.eos_token_id.unwrap_or(2) {
break;
}
// 将 token ID 解码为文字
let token_str = self.tokenizer
.decode(&[next_token], true)
.map_err(|e| anyhow::anyhow!("Decode error: {}", e))?;
// 流式输出(实际应用中发送给 Channel 或 WebSocket)
print!("{}", token_str);
std::io::Write::flush(&mut std::io::stdout())?;
generated_text.push_str(&token_str);
pos += token_ids.len();
token_ids = vec![next_token]; // 增量推理:只传新 token(利用 KV Cache)
}
println!();
Ok(generated_text)
}
}
15.5 Burn:训练自定义神经网络¶
Burn 是一个纯 Rust 的深度学习框架,支持多种后端(CPU/GPU)和训练功能:
Rust
use burn::{
module::Module,
nn::{Linear, LinearConfig, Relu, Dropout, DropoutConfig},
tensor::{backend::Backend, Tensor},
train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep},
};
/// 文本分类神经网络(用于情感分析)
#[derive(Module, Debug)]
pub struct TextClassifier<B: Backend> {
fc1: Linear<B>, // 全连接层 1
fc2: Linear<B>, // 全连接层 2
output: Linear<B>, // 输出层
dropout: Dropout, // Dropout 防止过拟合
relu: Relu, // 激活函数
}
impl<B: Backend> TextClassifier<B> {
/// 创建模型(hidden_size 通常是 BERT 的输出维度 768)
pub fn new(device: &B::Device, hidden_size: usize, num_classes: usize) -> Self {
Self {
fc1: LinearConfig::new(hidden_size, 256).init(device),
fc2: LinearConfig::new(256, 64).init(device),
output: LinearConfig::new(64, num_classes).init(device),
dropout: DropoutConfig::new(0.3).init(),
relu: Relu::new(),
}
}
/// 前向传播
/// input: [batch_size, hidden_size] (BERT 输出的 CLS embedding)
/// 返回:[batch_size, num_classes] (每个类别的 logits)
pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
// input → 256 → ReLU → Dropout → 64 → ReLU → num_classes
let x = self.fc1.forward(input); // [batch, 256]
let x = self.relu.forward(x);
let x = self.dropout.forward(x);
let x = self.fc2.forward(x); // [batch, 64]
let x = self.relu.forward(x);
let x = self.dropout.forward(x);
self.output.forward(x) // [batch, num_classes]
}
}
// 实现训练步骤
impl<B: burn::tensor::backend::AutodiffBackend> TrainStep<Batch<B>, ClassificationOutput<B>>
for TextClassifier<B>
{
fn step(&self, batch: Batch<B>) -> TrainOutput<ClassificationOutput<B>> {
let output = self.forward(batch.embeddings).softmax(1); // softmax 归一化
let loss = CrossEntropyLoss::new().forward(output.clone(), batch.labels.clone());
TrainOutput::new(self, loss.backward(), ClassificationOutput { loss, output, targets: batch.labels })
}
}
struct Batch<B: Backend> {
embeddings: Tensor<B, 2>, // [batch_size, hidden_size]
labels: Tensor<B, 1, Int>, // [batch_size]
}
15.6 WASM:在浏览器运行 Rust AI¶
将 Rust AI 推理编译为 WebAssembly,实现客户端 AI 推理(隐私保护,无需后端):
Bash
cargo add wasm-bindgen js-sys web-sys
cargo install wasm-pack
# 使用 candle 的 wasm feature(不需要 CUDA,用 WebGL 后端)
Rust
// src/lib.rs — WASM 导出
use wasm_bindgen::prelude::*;
use candle_core::{Device, Tensor};
#[wasm_bindgen]
pub struct WasmEmbedder {
// ... 在 WASM 中使用 CPU 后端(WebGL 支持实验中)
}
#[wasm_bindgen]
impl WasmEmbedder {
#[wasm_bindgen(constructor)]
pub async fn new() -> Result<WasmEmbedder, JsValue> {
// 从 fetch 加载模型权重(来自 CDN 或 IndexedDB 缓存)
todo!()
}
/// 文本嵌入(从 JS 调用)
#[wasm_bindgen]
pub fn embed(&self, text: &str) -> Vec<f32> {
// ... 返回 Float32Array 给 JavaScript
vec![]
}
}
// JavaScript 使用:
// import init, { WasmEmbedder } from './pkg/rust_ai.js';
// await init();
// const embedder = await new WasmEmbedder();
// const vec = embedder.embed("Hello world");
Bash
# 构建 WASM 包
wasm-pack build --target web --out-dir pkg
# 输出文件:
# pkg/rust_ai.js (JS 胶水代码)
# pkg/rust_ai_bg.wasm (WebAssembly 二进制)
# pkg/rust_ai.d.ts (TypeScript 类型声明!)
15.7 实战:构建流式 LLM HTTP 服务¶
Rust
// src/main.rs — 用 Axum + Candle 构建 llama.cpp server 的 Rust 版本
use axum::{
extract::State,
response::{IntoResponse, Response},
routing::post,
Json, Router,
};
use tokio::sync::Mutex;
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
#[derive(Clone)]
struct AppState {
model: Arc<Mutex<LlamaInference>>, // 模型需要互斥锁(推理是有状态的)
}
#[derive(Deserialize)]
struct CompletionRequest {
prompt: String,
max_tokens: usize,
temperature: f64,
stream: bool,
}
#[derive(Serialize)]
struct CompletionChunk {
text: String,
finish_reason: Option<String>,
tokens_generated: usize,
}
async fn completions(
State(state): State<AppState>,
Json(req): Json<CompletionRequest>,
) -> Response {
if req.stream {
// 流式响应(Server-Sent Events)
let stream = async_stream::stream! {
let mut model = state.model.lock().await;
// ... 循环生成 token,每个 token 作为一个 SSE 事件发送
yield CompletionChunk { text: "Hello".to_string(), finish_reason: None, tokens_generated: 1 };
};
// 转换为 SSE 响应
axum_sse_manager::Sse::new(stream).into_response()
} else {
// 非流式:等待全部生成完成
let mut model = state.model.lock().await;
let text = model.generate_stream(&req.prompt, req.max_tokens, req.temperature)
.unwrap_or_default();
Json(CompletionChunk { text, finish_reason: Some("stop".to_string()), tokens_generated: 0 }).into_response()
}
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let model = LlamaInference::from_pretrained(
"mistralai/Mistral-7B-Instruct-v0.3",
true, // 使用量化版本
)?;
let state = AppState {
model: Arc::new(Mutex::new(model)),
};
let app = Router::new()
.route("/v1/completions", post(completions))
.with_state(state);
let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await?;
println!("LLM Server running on http://localhost:8080");
axum::serve(listener, app).await?;
Ok(())
}
15.8 性能基准与优化¶
| 模型 | 后端 | 量化 | 速度(tokens/s) | 内存 |
|---|---|---|---|---|
| Mistral-7B | CPU(16核) | INT4 | ~8 t/s | 4GB |
| Mistral-7B | CUDA(RTX 3090) | FP16 | ~90 t/s | 14GB |
| Mistral-7B | CUDA(RTX 3090) | INT4 | ~150 t/s | 5GB |
| BERT-base | CPU | FP32 | 200 batch/s | 0.4GB |
| BERT-base | CUDA | FP32 | 2000 batch/s | 0.5GB |
关键优化技巧¶
Rust
// 1. 批处理推理(比逐条推理快 10-50 倍)
let embeddings = embedder.encode_batch(&texts)?; // 一次处理 32 条
// 2. Flash Attention(Candle 支持,减少显存,加速长序列)
// 在 candle-transformers 的 Mistral/Llama 模型中默认启用
// 3. 模型并行(多 GPU 推理)
// 需要 candle-core 的 cuda_async feature
// 4. 预热(第一次推理加载模型到 GPU,之后快很多)
let _ = model.generate_stream("预热", 1, 0.1)?;
// 5. 使用 safetensors 内存映射(零拷贝加载,节省内存)
// unsafe { VarBuilder::from_mmaped_safetensors(...) }
🏋️ 本章练习¶
- 基础:使用
sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2模型,对中文文档建立向量索引,实现语义搜索 - 进阶:用 Burn 训练一个简单的情感分类器(正面/负面),在中文影评数据集上微调
- 挑战:将 BERT 文本嵌入编译为 WASM,在浏览器中实现本地语义搜索(隐私保护)
📌 本章小结¶
| 框架 | 适用场景 |
|---|---|
| Candle | 生产推理,支持 HF 模型和量化,API 简洁 |
| Burn | 需要训练自定义模型,多后端支持 |
| ort (ONNX) | 跨平台部署(移动/嵌入式),ONNX 生态 |
| wasm-pack | 浏览器端 AI,隐私敏感场景 |
Candle 0.8 · Burn 0.14 · Rust 2024 Edition · 2025