跳转至

第 15 章:Rust AI 推理引擎

学习时间:6 小时 | 难度:⭐⭐⭐⭐⭐ 专家 | 前置知识:全部前置章节,特别是第 ⅜/11 章(所有权、智能指针、异步)


本章概览

Rust 在 AI 推理领域正在快速崛起:速度接近 C++、内存安全、无 GC 暂停、可编译为 WASM。本章覆盖 HuggingFace 的 Candle(Rust 原生 ML 框架)、Burn(深度学习训练框架)和基于 llama.cpp 的大模型推理。

学习目标:

  • 使用 Candle 加载并运行 HuggingFace 模型(BERT, Mistral, Llama)
  • 理解张量(Tensor)操作的 Rust API
  • 掌握模型量化(INT8/INT4)减少内存占用
  • 使用 Burn 框架定义自定义神经网络
  • 编译 Rust AI 推理代码为 WASM,在浏览器运行
  • 构建流式 LLM 推理服务(类似 llama.cpp server)

15.1 Rust AI 生态概览

框架 定位 特点
Candle 推理框架 HuggingFace 出品,支持 CUDA/Metal,API 类似 PyTorch
Burn 训练+推理 多后端(Libtorch/WGPU/NdArray),可训练自定义模型
llama.cpp + Rust LLM 推理 基于 C++ 核心,Rust 通过 FFI 调用,量化支持最完善
ort (ONNX Runtime) 生产推理 使用 ONNX Runtime C++ 库,支持几乎所有模型格式
tch-rs PyTorch 绑定 LibTorch 的 Rust 绑定,API 与 Python PyTorch 最接近

15.2 Candle:HuggingFace Rust 推理框架

TOML
# Cargo.toml
[dependencies]
candle-core   = { version = "0.8", features = ["cuda"] }  # CUDA 支持
candle-nn     = "0.8"          # 神经网络层(线性层、激活函数等)
candle-transformers = "0.8"   # HuggingFace 模型(BERT、GPT2、Llama 等)
tokenizers    = "0.21"        # HuggingFace 分词器
hf-hub        = "0.3"        # 自动从 HuggingFace Hub 下载模型
anyhow        = "1"
tokio         = { version = "1", features = ["full"] }

张量(Tensor)基础操作

Rust
use candle_core::{Device, DType, Tensor};

fn tensor_basics() -> anyhow::Result<()> {
    // 选择计算设备:CPU 或 CUDA GPU
    let device = Device::new_cuda(0).unwrap_or(Device::Cpu);
    // let device = Device::new_metal(0)?; // macOS M1/M2/M3 Apple Silicon

    // 创建张量
    let a = Tensor::new(&[1.0f32, 2.0, 3.0], &device)?;  // 1D 向量
    let b = Tensor::new(&[[1.0f32, 2.0], [3.0, 4.0]], &device)?; // 2D 矩阵

    println!("a shape: {:?}", a.shape());  // [3]
    println!("b shape: {:?}", b.shape());  // [2, 2]

    // 基础运算
    let c = (&a + 1.0)?;         // 标量加法:[2, 3, 4]
    let d = (&a * &a)?;          // 逐元素乘法:[1, 4, 9]
    let sum = a.sum_all()?;      // 求和:6.0
    let mean = a.mean_all()?;    // 均值:2.0

    // 矩阵乘法(重要:DNN 的核心操作)
    let w = Tensor::new(&[[1.0f32, 0.0], [0.0, 1.0]], &device)?; // 单位矩阵
    let result = b.matmul(&w)?;  // b @ w = b(不变)

    // 形状变换
    let flat = b.flatten_all()?;    // [[1,2],[3,4]] → [1,2,3,4]
    let reshaped = flat.reshape((2, 2))?;  // 回到 [2, 2]
    let transposed = b.t()?;        // 转置:[[1,3],[2,4]]

    // 索引与切片
    let first_row = b.get(0)?;      // 第一行:[1, 2]
    let sliced = b.i((.., 0..1))?;  // 所有行,第 0 列

    // 类型转换
    let int_tensor = a.to_dtype(DType::I64)?;  // float32 → int64

    Ok(())
}

15.3 BERT 文本嵌入

BERT 是最常用的文本理解模型,广泛用于语义搜索、文本分类等任务:

Rust
use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use tokenizers::Tokenizer;
use hf_hub::{api::sync::Api, Repo, RepoType};

/// 使用 BERT 将文本编码为向量(用于语义搜索)
pub struct TextEmbedder {
    model:     BertModel,
    tokenizer: Tokenizer,
    device:    Device,
}

impl TextEmbedder {
    /// 从 HuggingFace Hub 自动下载模型
    pub fn from_pretrained(model_id: &str) -> anyhow::Result<Self> {
        let device = Device::new_cuda(0).unwrap_or(Device::Cpu);
        let api = Api::new()?;  // 使用 HF_TOKEN 环境变量认证

        // hf_hub 自动缓存模型到 ~/.cache/huggingface/
        let repo = api.repo(Repo::new(model_id.to_string(), RepoType::Model));
        let config_path    = repo.get("config.json")?;
        let tokenizer_path = repo.get("tokenizer.json")?;
        let weights_path   = repo.get("model.safetensors")?;

        // 加载配置
        let config: BertConfig = serde_json::from_str(&std::fs::read_to_string(config_path)?)?;

        // 加载分词器
        let tokenizer = Tokenizer::from_file(tokenizer_path)?;

        // 加载模型权重(safetensors 格式,零拷贝内存映射)
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device)?
        };
        let model = BertModel::load(vb, &config)?;

        Ok(Self { model, tokenizer, device })
    }

    /// 将一批文本编码为向量
    pub fn encode_batch(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
        // 分词:文本 → token IDs
        let encodings = self.tokenizer
            .encode_batch(texts.to_vec(), true)  // true = 添加特殊 token([CLS], [SEP])
            .map_err(|e| anyhow::anyhow!("Tokenizer error: {}", e))?;

        // 找最长序列长度(统一 padding)
        let seq_len = encodings.iter().map(|e| e.len()).max().unwrap_or(0);

        // 构建输入张量
        let input_ids: Vec<u32> = encodings.iter()
            .flat_map(|enc| {
                let mut ids = enc.get_ids().to_vec();
                ids.resize(seq_len, 0); // padding with 0
                ids
            })
            .collect();

        let attention_mask: Vec<u32> = encodings.iter()
            .flat_map(|enc| {
                let mut mask = enc.get_attention_mask().to_vec();
                mask.resize(seq_len, 0);
                mask
            })
            .collect();

        let batch_size = texts.len();

        // 转为 Candle 张量(形状:[batch_size, seq_len])
        let input_ids_tensor = Tensor::from_vec(input_ids, (batch_size, seq_len), &self.device)?;
        let attention_mask_tensor = Tensor::from_vec(attention_mask, (batch_size, seq_len), &self.device)?;

        // 前向传播
        let outputs = self.model.forward(
            &input_ids_tensor,
            &attention_mask_tensor,
            None, // token_type_ids(段落 ID,BERT 特有,这里不需要)
        )?;

        // 取 [CLS] token 的输出作为句子向量(第一个 token)
        // outputs shape: [batch_size, seq_len, hidden_size]
        let cls_embeddings = outputs.i((.., 0, ..))?;  // [batch_size, hidden_size]

        // L2 归一化(余弦相似度计算的前提)
        let norm = cls_embeddings.sqr()?.sum(1)?.sqrt()?.unsqueeze(1)?;
        let normalized = (&cls_embeddings / &norm)?;

        // 转换为 Rust Vec
        let embeddings: Vec<Vec<f32>> = (0..batch_size)
            .map(|i| {
                normalized.get(i).unwrap().to_vec1::<f32>().unwrap()
            })
            .collect();

        Ok(embeddings)
    }

    /// 计算两段文本的语义相似度(余弦相似度)
    pub fn similarity(&self, text1: &str, text2: &str) -> anyhow::Result<f32> {
        let embeddings = self.encode_batch(&[text1, text2])?;
        let e1 = &embeddings[0];
        let e2 = &embeddings[1];

        // 余弦相似度 = e1 · e2(因为已经归一化,直接点积)
        let dot_product: f32 = e1.iter().zip(e2.iter()).map(|(a, b)| a * b).sum();
        Ok(dot_product)
    }
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // 使用多语言 BERT(mBERT),支持中文
    let embedder = TextEmbedder::from_pretrained("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")?;

    // 语义搜索示例
    let query = "如何学习 Rust?";
    let docs = [
        "Rust 是一门系统编程语言,注重安全性和性能",
        "Python 是一门动态类型的脚本语言",
        "《Rust 程序设计语言》是学习 Rust 的最佳书籍",
        "JavaScript 是 Web 开发的核心语言",
    ];

    let query_embedding = embedder.encode_batch(&[query])?;
    let doc_embeddings = embedder.encode_batch(&docs)?;

    // 找出最相关的文档
    let mut similarities: Vec<(f32, &str)> = doc_embeddings.iter()
        .zip(docs.iter())
        .map(|(emb, doc)| {
            let score: f32 = query_embedding[0].iter()
                .zip(emb.iter())
                .map(|(a, b)| a * b)
                .sum();
            (score, *doc)
        })
        .collect();

    similarities.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());

    println!("查询:{}", query);
    println!("相关文档:");
    for (score, doc) in &similarities {
        println!("  [{:.3}] {}", score, doc);
    }

    Ok(())
}

15.4 Llama 大模型推理与量化

Rust
// Cargo.toml 追加
// candle-transformers 已经包含 Llama 模型实现

use candle_core::{DType, Device, Tensor};
use candle_transformers::models::llama::{Cache, Config as LlamaConfig, Llama};
use candle_transformers::generation::LogitsProcessor;
use tokenizers::Tokenizer;

/// Llama/Mistral 大模型推理器
pub struct LlamaInference {
    model:     Llama,
    tokenizer: Tokenizer,
    cache:     Cache,          // KV Cache(加速生成)
    device:    Device,
    config:    LlamaConfig,
}

impl LlamaInference {
    pub fn from_pretrained(model_id: &str, use_quantized: bool) -> anyhow::Result<Self> {
        let device = Device::new_cuda(0).unwrap_or(Device::Cpu);
        let api = hf_hub::api::sync::Api::new()?;
        let repo = api.repo(hf_hub::Repo::new(model_id.to_string(), hf_hub::RepoType::Model));

        let config_path    = repo.get("config.json")?;
        let tokenizer_path = repo.get("tokenizer.json")?;

        let config: LlamaConfig = serde_json::from_str(&std::fs::read_to_string(config_path)?)?;
        let tokenizer = Tokenizer::from_file(tokenizer_path)?;

        // 量化模型使用 GGUF 格式(INT4/INT8),大幅降低内存
        // 非量化:Mistral-7B 需要 ~14GB VRAM (FP16)
        // INT4 量化:Mistral-7B 只需 ~4GB VRAM
        let dtype = if use_quantized { DType::F16 } else { DType::F32 };

        let weights = if use_quantized {
            // 使用 Q4_K_M 量化版本
            repo.get("model.gguf")?
        } else {
            repo.get("model.safetensors")?
        };

        let vb = unsafe {
            candle_nn::VarBuilder::from_mmaped_safetensors(&[weights], dtype, &device)?
        };

        let cache = Cache::new(true, dtype, &config, &device)?; // true = use_kv_cache
        let model = Llama::load(vb, &config)?;

        Ok(Self { model, tokenizer, cache, device, config })
    }

    /// 流式文本生成
    pub fn generate_stream(
        &mut self,
        prompt: &str,
        max_tokens: usize,
        temperature: f64,
    ) -> anyhow::Result<String> {
        // 将 prompt 编码为 token IDs
        let tokens = self.tokenizer
            .encode(prompt, true)
            .map_err(|e| anyhow::anyhow!("Tokenize error: {}", e))?;
        let mut token_ids: Vec<u32> = tokens.get_ids().to_vec();

        // 采样器:控制生成的随机性
        let mut logits_processor = LogitsProcessor::new(
            42,                          // 随机种子(固定 seed 使结果可复现)
            Some(temperature),           // temperature:越高越随机(0.7 是常用值)
            Some(0.9),                   // top_p:核采样,过滤低概率 token
        );

        let mut generated_text = String::new();
        let mut pos = 0; // 当前位置(用于 KV Cache)

        for _ in 0..max_tokens {
            // 准备输入张量
            let input = Tensor::from_vec(token_ids.clone(), (1, token_ids.len()), &self.device)?;

            // 前向传播
            let logits = self.model.forward(&input, pos, &mut self.cache)?;
            // logits shape: [1, seq_len, vocab_size]

            // 只取最后一个 token 的 logits
            let logits = logits.squeeze(0)?;           // [seq_len, vocab_size]
            let last_logits = logits.get(logits.dims()[0] - 1)?; // [vocab_size]

            // 采样下一个 token
            let next_token = logits_processor.sample(&last_logits)?;

            // 检查是否生成了 EOS(结束)token
            if next_token == self.config.eos_token_id.unwrap_or(2) {
                break;
            }

            // 将 token ID 解码为文字
            let token_str = self.tokenizer
                .decode(&[next_token], true)
                .map_err(|e| anyhow::anyhow!("Decode error: {}", e))?;

            // 流式输出(实际应用中发送给 Channel 或 WebSocket)
            print!("{}", token_str);
            std::io::Write::flush(&mut std::io::stdout())?;

            generated_text.push_str(&token_str);
            pos += token_ids.len();
            token_ids = vec![next_token]; // 增量推理:只传新 token(利用 KV Cache)
        }
        println!();

        Ok(generated_text)
    }
}

15.5 Burn:训练自定义神经网络

Burn 是一个纯 Rust 的深度学习框架,支持多种后端(CPU/GPU)和训练功能:

TOML
[dependencies]
burn = { version = "0.14", features = ["wgpu", "train", "vision"] }
Rust
use burn::{
    module::Module,
    nn::{Linear, LinearConfig, Relu, Dropout, DropoutConfig},
    tensor::{backend::Backend, Tensor},
    train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep},
};

/// 文本分类神经网络(用于情感分析)
#[derive(Module, Debug)]
pub struct TextClassifier<B: Backend> {
    fc1:     Linear<B>,     // 全连接层 1
    fc2:     Linear<B>,     // 全连接层 2
    output:  Linear<B>,     // 输出层
    dropout: Dropout,       // Dropout 防止过拟合
    relu:    Relu,          // 激活函数
}

impl<B: Backend> TextClassifier<B> {
    /// 创建模型(hidden_size 通常是 BERT 的输出维度 768)
    pub fn new(device: &B::Device, hidden_size: usize, num_classes: usize) -> Self {
        Self {
            fc1:     LinearConfig::new(hidden_size, 256).init(device),
            fc2:     LinearConfig::new(256, 64).init(device),
            output:  LinearConfig::new(64, num_classes).init(device),
            dropout: DropoutConfig::new(0.3).init(),
            relu:    Relu::new(),
        }
    }

    /// 前向传播
    /// input: [batch_size, hidden_size] (BERT 输出的 CLS embedding)
    /// 返回:[batch_size, num_classes] (每个类别的 logits)
    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        // input → 256 → ReLU → Dropout → 64 → ReLU → num_classes
        let x = self.fc1.forward(input);        // [batch, 256]
        let x = self.relu.forward(x);
        let x = self.dropout.forward(x);

        let x = self.fc2.forward(x);            // [batch, 64]
        let x = self.relu.forward(x);
        let x = self.dropout.forward(x);

        self.output.forward(x)                  // [batch, num_classes]
    }
}

// 实现训练步骤
impl<B: burn::tensor::backend::AutodiffBackend> TrainStep<Batch<B>, ClassificationOutput<B>>
    for TextClassifier<B>
{
    fn step(&self, batch: Batch<B>) -> TrainOutput<ClassificationOutput<B>> {
        let output = self.forward(batch.embeddings).softmax(1); // softmax 归一化
        let loss = CrossEntropyLoss::new().forward(output.clone(), batch.labels.clone());

        TrainOutput::new(self, loss.backward(), ClassificationOutput { loss, output, targets: batch.labels })
    }
}

struct Batch<B: Backend> {
    embeddings: Tensor<B, 2>, // [batch_size, hidden_size]
    labels:     Tensor<B, 1, Int>, // [batch_size]
}

15.6 WASM:在浏览器运行 Rust AI

将 Rust AI 推理编译为 WebAssembly,实现客户端 AI 推理(隐私保护,无需后端):

Bash
cargo add wasm-bindgen js-sys web-sys
cargo install wasm-pack

# 使用 candle 的 wasm feature(不需要 CUDA,用 WebGL 后端)
Rust
// src/lib.rs — WASM 导出
use wasm_bindgen::prelude::*;
use candle_core::{Device, Tensor};

#[wasm_bindgen]
pub struct WasmEmbedder {
    // ... 在 WASM 中使用 CPU 后端(WebGL 支持实验中)
}

#[wasm_bindgen]
impl WasmEmbedder {
    #[wasm_bindgen(constructor)]
    pub async fn new() -> Result<WasmEmbedder, JsValue> {
        // 从 fetch 加载模型权重(来自 CDN 或 IndexedDB 缓存)
        todo!()
    }

    /// 文本嵌入(从 JS 调用)
    #[wasm_bindgen]
    pub fn embed(&self, text: &str) -> Vec<f32> {
        // ... 返回 Float32Array 给 JavaScript
        vec![]
    }
}

// JavaScript 使用:
// import init, { WasmEmbedder } from './pkg/rust_ai.js';
// await init();
// const embedder = await new WasmEmbedder();
// const vec = embedder.embed("Hello world");
Bash
# 构建 WASM 包
wasm-pack build --target web --out-dir pkg

# 输出文件:
# pkg/rust_ai.js       (JS 胶水代码)
# pkg/rust_ai_bg.wasm  (WebAssembly 二进制)
# pkg/rust_ai.d.ts     (TypeScript 类型声明!)

15.7 实战:构建流式 LLM HTTP 服务

Rust
// src/main.rs — 用 Axum + Candle 构建 llama.cpp server 的 Rust 版本
use axum::{
    extract::State,
    response::{IntoResponse, Response},
    routing::post,
    Json, Router,
};
use tokio::sync::Mutex;
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

#[derive(Clone)]
struct AppState {
    model: Arc<Mutex<LlamaInference>>, // 模型需要互斥锁(推理是有状态的)
}

#[derive(Deserialize)]
struct CompletionRequest {
    prompt:      String,
    max_tokens:  usize,
    temperature: f64,
    stream:      bool,
}

#[derive(Serialize)]
struct CompletionChunk {
    text:             String,
    finish_reason:    Option<String>,
    tokens_generated: usize,
}

async fn completions(
    State(state): State<AppState>,
    Json(req): Json<CompletionRequest>,
) -> Response {
    if req.stream {
        // 流式响应(Server-Sent Events)
        let stream = async_stream::stream! {
            let mut model = state.model.lock().await;
            // ... 循环生成 token,每个 token 作为一个 SSE 事件发送
            yield CompletionChunk { text: "Hello".to_string(), finish_reason: None, tokens_generated: 1 };
        };

        // 转换为 SSE 响应
        axum_sse_manager::Sse::new(stream).into_response()
    } else {
        // 非流式:等待全部生成完成
        let mut model = state.model.lock().await;
        let text = model.generate_stream(&req.prompt, req.max_tokens, req.temperature)
            .unwrap_or_default();
        Json(CompletionChunk { text, finish_reason: Some("stop".to_string()), tokens_generated: 0 }).into_response()
    }
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let model = LlamaInference::from_pretrained(
        "mistralai/Mistral-7B-Instruct-v0.3",
        true, // 使用量化版本
    )?;

    let state = AppState {
        model: Arc::new(Mutex::new(model)),
    };

    let app = Router::new()
        .route("/v1/completions", post(completions))
        .with_state(state);

    let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await?;
    println!("LLM Server running on http://localhost:8080");
    axum::serve(listener, app).await?;

    Ok(())
}

15.8 性能基准与优化

模型 后端 量化 速度(tokens/s) 内存
Mistral-7B CPU(16核) INT4 ~8 t/s 4GB
Mistral-7B CUDA(RTX 3090) FP16 ~90 t/s 14GB
Mistral-7B CUDA(RTX 3090) INT4 ~150 t/s 5GB
BERT-base CPU FP32 200 batch/s 0.4GB
BERT-base CUDA FP32 2000 batch/s 0.5GB

关键优化技巧

Rust
// 1. 批处理推理(比逐条推理快 10-50 倍)
let embeddings = embedder.encode_batch(&texts)?; // 一次处理 32 条

// 2. Flash Attention(Candle 支持,减少显存,加速长序列)
// 在 candle-transformers 的 Mistral/Llama 模型中默认启用

// 3. 模型并行(多 GPU 推理)
// 需要 candle-core 的 cuda_async feature

// 4. 预热(第一次推理加载模型到 GPU,之后快很多)
let _ = model.generate_stream("预热", 1, 0.1)?;

// 5. 使用 safetensors 内存映射(零拷贝加载,节省内存)
// unsafe { VarBuilder::from_mmaped_safetensors(...) }

🏋️ 本章练习

  1. 基础:使用 sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2 模型,对中文文档建立向量索引,实现语义搜索
  2. 进阶:用 Burn 训练一个简单的情感分类器(正面/负面),在中文影评数据集上微调
  3. 挑战:将 BERT 文本嵌入编译为 WASM,在浏览器中实现本地语义搜索(隐私保护)

📌 本章小结

框架 适用场景
Candle 生产推理,支持 HF 模型和量化,API 简洁
Burn 需要训练自定义模型,多后端支持
ort (ONNX) 跨平台部署(移动/嵌入式),ONNX 生态
wasm-pack 浏览器端 AI,隐私敏感场景

Candle 0.8 · Burn 0.14 · Rust 2024 Edition · 2025