
05 - Distributed Reinforcement Learning

Study time: 3-4 hours · Importance: ⭐⭐⭐⭐ · Large-scale parallel training · Prerequisites: A3C, experience replay


🎯 Learning Objectives

After completing this chapter, you will be able to:

- Understand the core ideas of distributed RL
- Master the Ape-X architecture
- Understand distributed methods such as IMPALA
- Apply distributed training to speed up learning


1. Introduction to Distributed RL

1.1 Motivation

Limitations of a single machine:

- Limited compute resources
- Long training times
- Slow sample collection

Advantages of going distributed:

- Parallel environment interaction
- Faster training
- Ability to tackle large-scale problems

1.2 Architecture Taxonomy

Text Only
Distributed RL
├── Distributed actors, centralized learner
│   ├── Ape-X
│   └── R2D2
├── Distributed actors, distributed learners
│   ├── IMPALA
│   └── SEED
└── Parameter-server architectures
    └── A3C variants

2. Ape-X

2.1 Core Idea

Many actors, one learner:

- Actors collect experience in parallel
- Experiences flow into a prioritized replay buffer
- The learner trains from the buffer

2.2 Code Implementation

Python
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import gymnasium as gym
from collections import deque
import random
import numpy as np

class Actor:
    """Distributed actor: interacts with its own environment copy."""

    def __init__(self, actor_id, shared_network, replay_queue, epsilon):
        self.actor_id = actor_id
        self.network = shared_network
        self.replay_queue = replay_queue
        self.epsilon = epsilon

    def run(self, env_name):
        """Main actor loop: act, then ship experience to the learner."""
        env = gym.make(env_name)

        while True:
            state, _ = env.reset()
            done = False

            while not done:
                # ε-greedy exploration
                if random.random() < self.epsilon:
                    action = env.action_space.sample()
                else:
                    with torch.no_grad():  # no gradients needed for acting
                        state_tensor = torch.as_tensor(state, dtype=torch.float32)
                        q_values = self.network(state_tensor)
                        action = q_values.argmax().item()  # tensor -> Python int

                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated

                # Store `terminated` (not `done`): a truncated episode may
                # still be bootstrapped from, a truly terminal state may not
                experience = (state, action, reward, next_state, terminated)
                self.replay_queue.put(experience)

                state = next_state

class Learner:
    """Centralized learner: trains from the shared replay buffer."""

    def __init__(self, network, target_network, optimizer):
        self.network = network
        self.target_network = target_network
        self.optimizer = optimizer
        self.replay_buffer = PrioritizedReplayBuffer(capacity=1000000)

    def learn(self, batch_size=32):
        """Run one learning step from the buffer."""
        if len(self.replay_buffer) < batch_size:
            return None

        # Prioritized sampling
        batch, indices, weights = self.replay_buffer.sample(batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)  # unzip fields

        # Convert to tensors (stack into numpy arrays first: building a
        # tensor from a tuple of arrays is slow and raises a warning)
        states = torch.as_tensor(np.array(states), dtype=torch.float32)
        actions = torch.as_tensor(actions, dtype=torch.int64)
        rewards = torch.as_tensor(rewards, dtype=torch.float32)
        next_states = torch.as_tensor(np.array(next_states), dtype=torch.float32)
        dones = torch.as_tensor(dones, dtype=torch.float32)
        weights = torch.as_tensor(weights, dtype=torch.float32)

        # DQN update
        current_q = self.network(states).gather(1, actions.unsqueeze(1))  # Q(s,a)

        with torch.no_grad():
            next_q = self.target_network(next_states).max(1)[0]
            target_q = rewards + 0.99 * next_q * (1 - dones)

        # Importance-weighted loss
        td_errors = torch.abs(current_q.squeeze(1) - target_q)
        loss = (weights * td_errors.pow(2)).mean()

        self.optimizer.zero_grad()   # clear old gradients
        loss.backward()              # backpropagate
        self.optimizer.step()        # apply the update

        # Feed new priorities back to the buffer
        self.replay_buffer.update_priorities(indices, td_errors.detach().cpu().numpy())

        return loss.item()
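
The `Learner` above assumes a `PrioritizedReplayBuffer` class that this chapter never defines. The sketch below is a minimal proportional-prioritization version matching the interface the `Learner` calls (`sample`, `update_priorities`, `__len__`); the `add` method and the `alpha`/`beta` hyperparameters are illustrative assumptions, and the O(N) sampling here stands in for the sum-tree a production Ape-X buffer would use.

Python
class PrioritizedReplayBuffer:
    """Minimal proportional prioritized replay (illustrative sketch)."""

    def __init__(self, capacity, alpha=0.6, beta=0.4):
        self.capacity = capacity
        self.alpha = alpha  # how strongly priorities skew sampling
        self.beta = beta    # strength of importance-sampling correction
        self.buffer = []
        self.priorities = np.zeros(capacity, dtype=np.float32)
        self.pos = 0

    def add(self, experience):
        # New experiences get the current max priority so each is
        # sampled at least once before its TD error is known
        max_prio = self.priorities[:len(self.buffer)].max() if self.buffer else 1.0
        if len(self.buffer) < self.capacity:
            self.buffer.append(experience)
        else:
            self.buffer[self.pos] = experience
        self.priorities[self.pos] = max_prio
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size):
        prios = self.priorities[:len(self.buffer)] ** self.alpha
        probs = prios / prios.sum()
        indices = np.random.choice(len(self.buffer), batch_size, p=probs)
        batch = [self.buffer[i] for i in indices]
        # Importance-sampling weights correct the non-uniform sampling bias
        weights = (len(self.buffer) * probs[indices]) ** (-self.beta)
        weights /= weights.max()
        return batch, indices, weights

    def update_priorities(self, indices, td_errors):
        # Small epsilon keeps every transition sampleable
        self.priorities[indices] = np.abs(td_errors) + 1e-6

    def __len__(self):
        return len(self.buffer)

In a full Ape-X setup, a separate thread on the learner process would drain `replay_queue` into this buffer via `add()` while the training loop samples from it.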

3. IMPALA

3.1 Core Idea

The Importance-Weighted Actor-Learner Architecture:

- Multiple actors run in parallel
- Multiple learners can run in parallel
- V-trace corrects the off-policy bias between the two

3.2 V-trace

Truncated importance sampling ratio: $\rho_t = \min\left(\bar{\rho}, \frac{\pi(a_t|s_t)}{\mu(a_t|s_t)}\right)$, with $c_i = \min\left(\bar{c}, \frac{\pi(a_i|s_i)}{\mu(a_i|s_i)}\right)$ defined analogously (typically $\bar{c} \le \bar{\rho}$).

V-trace target: $v_s = V(s_s) + \sum_{t=s}^{s+n-1} \gamma^{t-s} \left(\prod_{i=s}^{t-1} c_i\right) \rho_t \delta_t$, where $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$ is the TD error.

3.3 Code Implementation

Python
class IMPALA:
    """IMPALA: V-trace target computation."""

    def __init__(self, network, rho_bar=1.0, c_bar=1.0):
        self.network = network
        self.rho_bar = rho_bar
        self.c_bar = c_bar

    def vtrace(self, rewards, values, log_probs_policy, log_probs_behavior, gamma=0.99):
        """
        Compute V-trace targets for one trajectory.

        Args:
            rewards: reward sequence, shape [T]
            values: value estimates, shape [T+1] (includes the bootstrap
                value; assumed detached, since targets carry no gradient)
            log_probs_policy: log π(a_t|s_t) under the learner policy, shape [T]
            log_probs_behavior: log μ(a_t|s_t) under the behavior policy, shape [T]
        """
        # Importance sampling ratios π/μ, computed in log space for stability
        log_rhos = log_probs_policy - log_probs_behavior
        rhos = torch.exp(log_rhos)

        # Truncate the ratios to bound the variance
        clipped_rhos = torch.clamp(rhos, max=self.rho_bar)
        clipped_cs = torch.clamp(rhos, max=self.c_bar)

        # ρ-weighted TD errors: ρ_t (r_t + γ V(s_{t+1}) - V(s_t))
        deltas = clipped_rhos * (rewards + gamma * values[1:] - values[:-1])

        # Backward recursion: v_t = V(s_t) + δ_t + γ c_t (v_{t+1} - V(s_{t+1}))
        vs = torch.zeros_like(values)
        vs[-1] = values[-1]  # bootstrap from the final value estimate

        for t in reversed(range(len(rewards))):
            vs[t] = values[t] + deltas[t] + gamma * clipped_cs[t] * (vs[t+1] - values[t+1])

        return vs
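
A quick sanity check of `vtrace()` above, using random tensors whose shapes follow its contract (T rewards and log-probabilities, T+1 values); all numbers here are purely illustrative.

Python
# Illustrative shapes only: T = 5 transitions plus one bootstrap value
T = 5
rewards = torch.randn(T)
values = torch.randn(T + 1)       # includes the bootstrap value V(s_T)
log_probs_policy = torch.randn(T)
log_probs_behavior = torch.randn(T)

impala = IMPALA(network=None)     # vtrace() itself never touches the network
vs = impala.vtrace(rewards, values, log_probs_policy, log_probs_behavior)
print(vs.shape)                   # torch.Size([6]): one target per state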

4. Comparison of Distributed Methods

| Method | Actors | Learner | Key Feature |
| --- | --- | --- | --- |
| Ape-X | many, distributed | single, centralized | prioritized replay |
| IMPALA | many, distributed | one or more | V-trace correction |
| R2D2 | many, distributed | single, centralized | recurrent network |
| SEED | environments only (inference on learner) | single, centralized | gRPC communication |

5. Chapter Summary

Core Concepts

Text Only
Distributed RL:
├── Parallel actors: faster sample collection
├── Prioritized replay: more efficient use of samples
├── V-trace: corrects off-policy bias
└── Scalability: handles large-scale problems

Which to choose:
├── Research: IMPALA
├── Engineering: Ape-X
└── Very large scale: SEED

✅ Self-Check Questions

  1. What advantages does distributed RL have over single-machine RL?

  2. How does prioritized replay work in Ape-X?

  3. What does V-trace do?


📚 Further Reading

  1. Horgan et al. (2018) - Distributed Prioritized Experience Replay (Ape-X)
  2. Espeholt et al. (2018) - IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures
  3. Kapturowski et al. (2019) - Recurrent Experience Replay in Distributed Reinforcement Learning (R2D2)

→ Next stage: 05 - Hands-On Projects