PyTorch实战:用Actor-Critic算法玩转CartPole游戏(附完整代码)
从零到一用PyTorch手搓Actor-Critic让CartPole杆子屹立不倒如果你已经对强化学习的基本概念有所了解比如知道什么是智能体、环境、状态、动作和奖励并且用PyTorch写过几个简单的神经网络那么恭喜你你已经具备了“折腾”Actor-Critic演员-评论家算法的全部前置条件。很多教程一上来就大谈特谈策略梯度和时序差分理论公式看得人头昏脑胀但代码跑起来却错误百出。今天我们不搞那些虚的就从一个最经典的测试环境——CartPole倒立摆开始用PyTorch一行一行地把Actor-Critic实现出来看看这个“双网络”架构到底是怎么协同工作让一个小车学会平衡杆子的。想象一下你教一个完全不懂平衡的新手骑自行车。一开始他左摇右晃随机动作你评论家会在旁边即时评价“往左倒啦快往右打一点方向”TD误差。他演员根据你的评价调整自己的肌肉发力策略更新。经过无数次摔倒和修正他终于能稳稳骑上一段路。Actor-Critic就是这个过程的数学抽象和自动化。本文将聚焦于代码的实战实现、训练过程中的关键调试点以及如何直观地可视化学习效果目标是让你不仅能跑通代码更能理解每一行代码背后的意图从而具备修改和适配新问题的能力。1. 环境搭建与核心思想透视在动手写代码之前我们需要把“战场”布置好。CartPole-v1是OpenAI Gym现在主流是Gymnasium中的一个经典环境其状态空间是4维的连续值小车位置、速度、杆子角度、角速度动作空间是2维的离散值向左或向右施力。我们的智能体需要学习一个策略通过左右施力来防止杆子倒下。Actor-Critic之所以强大在于它巧妙地融合了两种思想策略梯度Policy Gradient直接优化策略本身适合高维或连续动作空间。但它的更新方差大学习不稳定像个容易激动、需要不断鼓励或批评的“演员”。价值函数Value Function评估状态或动作的好坏能提供更稳定、更低方差的更新信号。它像一个冷静的“评论家”负责评价演员的表现。AC算法让“演员”Actor网络负责生成动作“评论家”Critic网络负责评价当前状态的价值。演员根据评论家的评价通常是优势函数来更新自己的策略使其更倾向于选择能获得更高评价的动作。这是一种单步更新、在线学习的算法意味着它在与环境交互的每一步之后都可以进行更新学习效率相对较高。我们先来快速搭建项目的基础结构。确保你已安装必要的库pip install gymnasium torch numpy matplotlib tqdm注意本文使用gymnasium作为环境库它是OpenAI Gym的维护分支。如果你遇到与gym相关的导入错误请检查库的安装。2. 双剑合璧Actor与Critic网络架构详解Actor-Critic的核心是两个神经网络。它们结构通常不复杂但设计意图截然不同。我们分别实现它们。2.1 Actor网络策略的决策者Actor网络的任务是输入当前状态s输出每个可能动作的概率分布。对于CartPole就是输出向左和向右的概率。我们使用一个简单的两层全连接网络最后一层用Softmax激活函数来确保输出是合法的概率分布和为1。import torch import torch.nn as nn import torch.nn.functional as F class ActorNetwork(nn.Module): 策略网络Actor。 输入状态state_dim维向量 输出各个动作的概率分布action_dim维向量 def __init__(self, state_dim, hidden_dim, action_dim): super(ActorNetwork, self).__init__() self.fc1 nn.Linear(state_dim, hidden_dim) # 第一层全连接 self.fc2 nn.Linear(hidden_dim, action_dim) # 第二层全连接输出动作logits def forward(self, state): x F.relu(self.fc1(state)) # 第一层后接ReLU激活 logits self.fc2(x) # 输出原始分数logits action_probs F.softmax(logits, dim-1) # 转换为概率 return action_probs关键点解析fc2层直接输出称为logits的原始分数不包含激活函数。F.softmax(logits, dim-1)在最后一个维度即动作维度上进行Softmax运算将logits转化为概率。dim-1的写法比dim1更通用能处理不同批次大小的输入。这个网络定义了我们的策略 π(a|s)。2.2 Critic网络价值的评估者Critic网络的任务是评估当前状态s的价值 V(s)即从状态s开始遵循当前策略所能获得的期望累积回报。它也是一个回归网络输出一个标量值。class CriticNetwork(nn.Module): 价值网络Critic。 输入状态state_dim维向量 输出状态价值 V(s)标量 def __init__(self, state_dim, hidden_dim): super(CriticNetwork, self).__init__() self.fc1 nn.Linear(state_dim, hidden_dim) self.fc2 nn.Linear(hidden_dim, 1) # 输出层只有一个神经元表示价值 def forward(self, state): x F.relu(self.fc1(state)) state_value self.fc2(x) # 注意这里没有激活函数因为价值可以是任意实数 return state_value为什么Critic输出层不用激活函数状态价值 V(s) 理论上可以是任意实数正、负或零使用线性输出更合适。如果担心梯度爆炸可以在训练时使用梯度裁剪等技术。为了更清晰地对比这两个网络我们用一个表格总结特性Actor网络 (策略网络)Critic网络 (价值网络)角色决策者演员评估者评论家输入状态 s状态 s输出动作概率分布 π(a|s)状态价值估计 V(s)输出层激活Softmax (归一化为概率)无 (线性层)目标最大化期望回报最小化价值估计误差 (如TD误差)更新信号优势函数 (A)时序差分误差 (TD Error)3. 智能体核心动作选择与网络更新机制有了网络我们需要一个智能体Agent来统筹管理它们包括根据策略选择动作以及收集经验后更新网络参数。3.1 动作采样从概率到决策智能体如何根据Actor网络输出的概率分布来执行动作我们不能总是选择概率最大的动作贪心策略因为这会限制探索。我们使用依概率采样。class ActorCriticAgent: def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, gamma, device): self.actor ActorNetwork(state_dim, hidden_dim, action_dim).to(device) self.critic CriticNetwork(state_dim, hidden_dim).to(device) self.actor_optimizer torch.optim.Adam(self.actor.parameters(), lractor_lr) self.critic_optimizer torch.optim.Adam(self.critic.parameters(), lrcritic_lr) self.gamma gamma # 折扣因子 self.device device def take_action(self, state): 根据当前策略采样一个动作。 参数 state: 环境返回的状态numpy数组或列表。 返回: 动作的整型索引。 # 将状态转换为Tensor并添加批次维度 [batch_size1, state_dim] state_tensor torch.tensor([state], dtypetorch.float).to(self.device) # 获取动作概率分布 action_probs self.actor(state_tensor) # 形状 [1, action_dim] # 创建分类分布并采样 action_dist torch.distributions.Categorical(action_probs) action action_dist.sample() # 采样得到一个包含单个动作的Tensor return action.item() # 返回Python整型torch.tensor([state], ...)环境返回的state通常是np.ndarray。我们将其转换为PyTorch Tensor并用[state]将其包装成形状为[1, state_dim]的批次数据以符合网络输入要求。torch.distributions.CategoricalPyTorch提供的分类分布它接收一个概率向量和为1并允许我们从中采样、计算对数概率等非常方便。action.item()将包含单个标量的PyTorch Tensor转换为Python整数如0或1以便传入环境。3.2 核心更新TD误差与策略梯度这是AC算法最精妙的部分。我们在一轮游戏episode结束后用收集到的所有状态、动作、奖励序列来更新网络。更新基于时序差分误差。假设我们收集了一个序列的数据存储在字典transition_dict中包含states,actions,rewards,next_states,dones这几个列表。第一步数据准备与TD目标计算def update(self, transition_dict): # 将列表数据转换为Tensor并确保正确的形状 states torch.tensor(transition_dict[states], dtypetorch.float).to(self.device) actions torch.tensor(transition_dict[actions]).view(-1, 1).to(self.device) # [batch_size, 1] rewards torch.tensor(transition_dict[rewards], dtypetorch.float).view(-1, 1).to(self.device) next_states torch.tensor(transition_dict[next_states], dtypetorch.float).to(self.device) dones torch.tensor(transition_dict[dones], dtypetorch.float).view(-1, 1).to(self.device) # 计算TD目标 r γ * V(s) * (1 - done) with torch.no_grad(): # 计算TD目标时不需梯度 next_state_values self.critic(next_states) td_target rewards self.gamma * next_state_values * (1 - dones) # 计算当前状态的价值估计 state_values self.critic(states) # 计算TD误差 δ TD目标 - V(s) td_delta td_target - state_values # 这个td_delta就是优势函数A(s,a)的近似.view(-1, 1)将一维的动作、奖励、完成标志向量重塑为[batch_size, 1]的列向量方便后续的广播计算。dones表示当前状态是否为终止状态如杆子倒下。如果是终止状态done1则下一个状态的价值V(s)应为0因此乘以(1 - dones)。with torch.no_grad()在计算td_target时我们不需要对next_states的价值计算梯度因为目标值应该被视为一个固定的“标签”。这是稳定训练的关键。第二步计算损失并更新网络# 计算Actor损失策略梯度损失使用TD误差作为权重 # 首先获取所选动作的对数概率 action_probs self.actor(states) # [batch_size, action_dim] # gather(1, actions) 根据actions索引取出对应动作的概率 log_probs torch.log(action_probs.gather(1, actions)) # [batch_size, 1] # 策略梯度损失 -log(π(a|s)) * δ actor_loss torch.mean(-log_probs * td_delta.detach()) # 计算Critic损失价值估计的均方误差 critic_loss torch.mean(F.mse_loss(state_values, td_target.detach())) # 清空梯度反向传播更新参数 self.actor_optimizer.zero_grad() self.critic_optimizer.zero_grad() actor_loss.backward() critic_loss.backward() self.actor_optimizer.step() self.critic_optimizer.step()action_probs.gather(1, actions)这是一个关键操作。actions的形状是[batch_size, 1]它包含了每个状态下实际执行的动作索引。gather函数沿着第1维动作维度根据这些索引从action_probs中取出对应的概率值。td_delta.detach()在计算actor_loss时我们将td_delta从计算图中分离detach。这意味着我们只使用td_delta的数值作为权重而不让Actor的更新影响Critic的梯度计算。这是一种常用的技巧用于稳定两个网络的联合训练。F.mse_lossCritic的目标是让自己对状态的估值state_values尽可能接近td_target因此使用均方误差损失。4. 训练循环、可视化与实战调参现在我们把所有部分组装起来进入实际的训练流程。我们将编写一个训练函数并观察智能体是如何从零开始学习平衡杆子的。4.1 完整的训练流程我们使用一个简单的回合制训练循环。在每个回合中智能体与环境交互直到结束收集整个轨迹的数据然后一次性更新网络。import gymnasium as gym import numpy as np from tqdm import tqdm import matplotlib.pyplot as plt def train_agent(env, agent, num_episodes): return_list [] # 记录每个回合的总回报 for i in range(10): # 将总回合数分成10个阶段方便用tqdm显示进度 with tqdm(totalint(num_episodes/10), descfIteration {i}) as pbar: for episode in range(int(num_episodes/10)): episode_return 0 state, _ env.reset() transition_dict { states: [], actions: [], next_states: [], rewards: [], dones: [] } done False while not done: action agent.take_action(state) next_state, reward, done, truncated, _ env.step(action) # 存储转移数据 transition_dict[states].append(state) transition_dict[actions].append(action) transition_dict[next_states].append(next_state) transition_dict[rewards].append(reward) transition_dict[dones].append(done) state next_state episode_return reward # 一个回合结束用收集到的数据更新智能体 agent.update(transition_dict) return_list.append(episode_return) # 每10个回合更新一次进度条显示 if (episode 1) % 10 0: pbar.set_postfix({ episode: f{int(num_episodes/10)*i episode 1}, return: f{np.mean(return_list[-10:]):.1f} # 显示最近10个回合的平均回报 }) pbar.update(1) return return_list4.2 启动训练与结果可视化现在让我们配置超参数创建环境与智能体并开始训练。# 超参数设置 actor_learning_rate 1e-3 # Actor网络学习率通常设小一点 critic_learning_rate 1e-2 # Critic网络学习率可以比Actor大一点让它学得更快 num_episodes 1000 # 训练总回合数 hidden_dim 128 # 网络隐藏层维度 gamma 0.98 # 折扣因子接近1表示更重视远期回报 device torch.device(cuda if torch.cuda.is_available() else cpu) # 创建环境 env_name CartPole-v1 env gym.make(env_name, render_modeNone) # 训练时不需要渲染 state_dim env.observation_space.shape[0] action_dim env.action_space.n # 创建智能体 agent ActorCriticAgent(state_dim, hidden_dim, action_dim, actor_learning_rate, critic_learning_rate, gamma, device) print(fTraining on {device}...) print(fState dim: {state_dim}, Action dim: {action_dim}) # 开始训练 returns train_agent(env, agent, num_episodes) env.close()训练完成后我们绘制学习曲线来直观评估性能。# 绘制原始回报曲线 episodes_list list(range(len(returns))) plt.figure(figsize(12, 4)) plt.subplot(1, 2, 1) plt.plot(episodes_list, returns) plt.xlabel(Episode) plt.ylabel(Return) plt.title(Raw Returns of Actor-Critic on CartPole) # 绘制滑动平均回报曲线更平滑便于观察趋势 def moving_average(data, window_size): 计算滑动平均 return np.convolve(data, np.ones(window_size)/window_size, modevalid) window_size 19 mv_returns moving_average(returns, window_size) mv_episodes list(range(window_size-1, len(returns))) plt.subplot(1, 2, 2) plt.plot(mv_episodes, mv_returns) plt.xlabel(Episode) plt.ylabel(Average Return) plt.title(fMoving Average (window{window_size}) of Returns) plt.tight_layout() plt.show()一个成功的训练其回报曲线会从最初的随机水平CartPole-v1的随机策略回报大约在20-40之间逐渐上升并稳定在接近最高分200分附近CartPole-v1的回合终止条件是杆子倒下或持续200步。4.3 关键调参经验与常见“坑点”直接运行上面的代码可能不会一次成功或者性能不佳。以下是一些实战中总结的经验学习率LR这是最重要的超参数。通常Critic的学习率可以比Actor大例如1e-2 vs 1e-3因为价值函数通常比策略更容易收敛。如果回报曲线震荡剧烈或发散尝试调小学习率。折扣因子Gammagamma0.98或0.99对于CartPole是常见选择。它控制了未来奖励的重要性。值太小会导致智能体短视值太大会使学习不稳定。网络结构对于简单环境如CartPole一层128维的隐藏层通常足够。如果学习效果不好可以尝试增加隐藏层维度如256。增加网络深度如两层隐藏层但要小心过拟合和梯度消失。梯度爆炸/消失在深度强化学习中很常见。可以尝试梯度裁剪在backward()之后、step()之前添加torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm1.0)。使用更稳定的激活函数如ReLU。探索不足如果智能体很快陷入次优策略可以尝试在take_action中引入探索机制例如ε-贪心策略以ε的概率随机选择动作或者使用熵正则化项在Actor损失中加入-beta * entropy以鼓励探索。Critic的估计不准如果Critic的价值估计误差很大会误导Actor。确保Critic有足够的能力合适的网络容量和更稳定的学习可能使用目标网络但这是更高级的技巧如A2C、PPO中常用。我在自己的机器上跑这个实验时最初将Critic的学习率设得和Actor一样1e-3发现收敛速度很慢。后来将其调整为1e-2后大约在300-400个回合后滑动平均回报就能稳定在190以上。另一个常见的错误是在计算对数概率时没有处理好张量形状导致gather函数报错务必确保actions张量是[batch_size, 1]的形状并且数据类型是torch.long。最后你可以尝试修改环境比如MountainCar-v0或Pendulum-v1注意后者是连续动作空间需要修改Actor网络输出高斯分布的参数或者尝试更先进的AC变体如带优势函数归一化的A2C或者加入了重要性采样和裁剪机制的PPO算法那将是强化学习实战旅程的下一站。

相关新闻

AMD ROCm 4.2实战:手把手教你用HIP API实现GPU内核调度(附队列优化技巧)

AMD ROCm 4.2实战:手把手教你用HIP API实现GPU内核调度(附队列优化技巧)

AMD ROCm 4.2实战:手把手教你用HIP API实现GPU内核调度(附队列优化技巧) 在异构计算的世界里,将任务高效地“投喂”给GPU,并确保其以最优的方式执行,是每个追求极致性能的开发者必须面对的课题。AMD ROCm平…

2026/5/17 5:10:21 阅读更多 →
个人创作者利器:用EasyAnimateV5图生视频模型低成本制作高质量短视频

个人创作者利器:用EasyAnimateV5图生视频模型低成本制作高质量短视频

个人创作者利器:用EasyAnimateV5图生视频模型低成本制作高质量短视频 1. 一张图,让创意动起来 你有没有过这样的时刻?拍了一张特别满意的照片,无论是精心设计的静物,还是抓拍到的动人瞬间,心里总想着&…

2026/5/17 12:09:17 阅读更多 →
QtScrcpy:跨设备协同与无感化控制的技术革命

QtScrcpy:跨设备协同与无感化控制的技术革命

QtScrcpy:跨设备协同与无感化控制的技术革命 【免费下载链接】QtScrcpy QtScrcpy 可以通过 USB / 网络连接Android设备,并进行显示和控制。无需root权限。 项目地址: https://gitcode.com/GitHub_Trending/qt/QtScrcpy 在数字化办公与多设备交互日…

2026/7/4 21:00:50 阅读更多 →

最新新闻

智能汽车板级接口与存储系统核心技术解析

智能汽车板级接口与存储系统核心技术解析

1. 智能汽车板级接口技术全景解析 作为一名在汽车电子领域深耕多年的工程师,我见证了车载电子系统从简单的ECU控制到如今复杂域控制器的演进历程。现代智能汽车的"大脑"——域控制器内部,各类芯片间的通信架构设计直接决定了系统性能上限。让我…

2026/7/5 10:37:10 阅读更多 →
AI服务合规网关实战:GDPR日志脱敏、国密SM4加密与审计追踪

AI服务合规网关实战:GDPR日志脱敏、国密SM4加密与审计追踪

1. 项目概述:一场迫在眉睫的合规风暴最近在排查一个线上AI服务的问题时,我遇到了一个典型的报错:cc switch deepseek unexpected status 502 bad gateway: unknown error, url: ht...。这个错误本身指向的是服务网关的切换或配置问题&#xf…

2026/7/5 10:35:10 阅读更多 →
光伏逆变器LVRT技术:Boost+NPC拓扑设计与控制策略

光伏逆变器LVRT技术:Boost+NPC拓扑设计与控制策略

1. 光伏逆变器低电压穿越技术概述 光伏发电系统在电网电压骤降时能否保持并网运行,直接关系到整个电力系统的稳定性。低电压穿越(LVRT)技术就是让逆变器在电网电压跌落时,不仅不脱网还能向电网提供无功功率支撑的关键能力。传统方案中,当检测…

2026/7/5 10:33:10 阅读更多 →
Allen Bradley 80190-378-51/12控制器板功能与应用解析

Allen Bradley 80190-378-51/12控制器板功能与应用解析

1. Allen Bradley 80190-378-51/12控制器板概述Allen Bradley 80190-378-51/12控制器板是罗克韦尔自动化旗下Allen-Bradley品牌推出的一款工业级控制电路板。作为自动化控制系统中的核心组件,它主要负责信号采集、逻辑运算和设备控制等功能。这款控制器板采用成熟的…

2026/7/5 10:31:10 阅读更多 →
解锁网易云音乐加密格式:ncmdump工具的全面应用指南

解锁网易云音乐加密格式:ncmdump工具的全面应用指南

解锁网易云音乐加密格式:ncmdump工具的全面应用指南 【免费下载链接】ncmdump 项目地址: https://gitcode.com/gh_mirrors/ncmd/ncmdump 你是否曾经遇到过这样的困扰:在网易云音乐下载的歌曲只能在特定应用内播放,无法在其他设备或播…

2026/7/5 10:31:10 阅读更多 →
I型NPC三电平逆变器SVPWM仿真设计与控制策略

I型NPC三电平逆变器SVPWM仿真设计与控制策略

1. I型NPC三电平逆变器SVPWM仿真设计概述在电力电子领域,三电平逆变器因其输出电压谐波含量低、开关损耗小等优势,已成为中高压大功率应用的首选拓扑结构。I型NPC(Neutral Point Clamped)三电平逆变器通过钳位二极管将直流母线中点…

2026/7/5 10:29:09 阅读更多 →

日新闻

B站视频下载神器BiliTools:5分钟学会轻松保存任何B站内容

B站视频下载神器BiliTools:5分钟学会轻松保存任何B站内容

B站视频下载神器BiliTools:5分钟学会轻松保存任何B站内容 【免费下载链接】BiliTools A cross-platform bilibili toolbox. 跨平台哔哩哔哩工具箱,支持下载视频、番剧等等各类资源 项目地址: https://gitcode.com/GitHub_Trending/bilit/BiliTools …

2026/7/5 0:03:34 阅读更多 →
威胁模型全解析:从新手入门到实战应用,助你构建安全产品!

威胁模型全解析:从新手入门到实战应用,助你构建安全产品!

威胁模型的陌生现状在忙碌疲惫的一天里,参与了关于混合后量子密码学的讨论,应付端点攻击找茬的人,还参与留言板讨论后,发现“威胁模型”对多数人仍是陌生概念,且多被当作时髦用语。有趣的相关画作有一幅由 Embyr 创作的…

2026/7/5 0:03:34 阅读更多 →
渗透测试入门指南:从零基础到实战环境搭建

渗透测试入门指南:从零基础到实战环境搭建

1. 从“看热闹”到“入门”:我理解的渗透测试到底是什么?每次看到新闻里说某个大公司的数据被“黑”了,或者某个网站被攻击导致服务瘫痪,你是不是和我一样,心里会冒出两个念头:一是“这黑客真厉害”&#x…

2026/7/5 0:07:38 阅读更多 →

周新闻

B站视频下载神器BiliTools:5分钟学会轻松保存任何B站内容

B站视频下载神器BiliTools:5分钟学会轻松保存任何B站内容

B站视频下载神器BiliTools:5分钟学会轻松保存任何B站内容 【免费下载链接】BiliTools A cross-platform bilibili toolbox. 跨平台哔哩哔哩工具箱,支持下载视频、番剧等等各类资源 项目地址: https://gitcode.com/GitHub_Trending/bilit/BiliTools …

2026/7/5 0:03:34 阅读更多 →
威胁模型全解析:从新手入门到实战应用,助你构建安全产品!

威胁模型全解析:从新手入门到实战应用,助你构建安全产品!

威胁模型的陌生现状在忙碌疲惫的一天里,参与了关于混合后量子密码学的讨论,应付端点攻击找茬的人,还参与留言板讨论后,发现“威胁模型”对多数人仍是陌生概念,且多被当作时髦用语。有趣的相关画作有一幅由 Embyr 创作的…

2026/7/5 0:03:34 阅读更多 →
渗透测试入门指南:从零基础到实战环境搭建

渗透测试入门指南:从零基础到实战环境搭建

1. 从“看热闹”到“入门”:我理解的渗透测试到底是什么?每次看到新闻里说某个大公司的数据被“黑”了,或者某个网站被攻击导致服务瘫痪,你是不是和我一样,心里会冒出两个念头:一是“这黑客真厉害”&#x…

2026/7/5 0:07:38 阅读更多 →

月新闻