Reinforce算法
目录一、Reinforce介绍二、REINFORCE baseline三、证明为啥可以降低方差1.计算策略梯度的方差2.先处理第二项​编辑3.所以上述相当于找到b优化第一项四、证明重要性质五、示例代码1.解释2.Reinforce解释3.代码一、Reinforce介绍最原始的 REINFORCE 更新公式是其中R代表Q(S,A),也就是某个轨迹的放缩reward。Reinfore的特点就是通过蒙特卡洛采样的方法采样一个轨迹之后得到Q(S,A)。上述梯度计算可能方差比较大为了降低方差引入了baseline。二、REINFORCE baseline三、证明为啥可以降低方差对于Reinforce的这个b(s),通常取一个轨迹的滑动平均。下面证明这个取法为啥可以降低方差。1.计算策略梯度的方差2.先处理第二项3.所以上述相当于找到b优化第一项这是关于 b 的二次函数对 b 求导结论最优 baseline四、证明重要性质因为梯度对参数求导与对动作求和无关五、示例代码1.解释下面的代码是使用强化学习做一个任务分配的问题机器人和任务的输入都是(x,y)的二维坐标。之后reward是-欧式距离。算法最后需要找到距离当前机器人最近的任务的策略。2.Reinforce解释由于在这个环境里面一个reward就是一个轨迹所以reward Q(S,A)baseline 就使用滑动平均替代。r -torch.norm(task_xy[action] - robot_xy) / 10.0 r_item float(r.detach().cpu().item()) baseline (1 - beta) * baseline beta * r_item advantage (r - baseline).detach()3.代码import math import random import numpy as np import torch import torch.nn as nn import os # ------------------------- # Reproducibility # ------------------------- seed 42 random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) device cuda if torch.cuda.is_available() else cpu MODEL_PATH reinforce_cross_attn_ckpt.pth # RL ckpt # ------------------------- # Model # ------------------------- class CrossAttnChooser(nn.Module): def __init__(self, d_model32, hidden64): super().__init__() self.robot_enc nn.Sequential( nn.Linear(2, hidden), nn.Tanh(), nn.Linear(hidden, d_model), ) self.task_enc nn.Sequential( nn.Linear(2, hidden), nn.Tanh(), nn.Linear(hidden, d_model), ) self.Wq nn.Linear(d_model, d_model, biasFalse) self.Wk nn.Linear(d_model, d_model, biasFalse) self.Wv nn.Linear(d_model, d_model, biasFalse) self.post_ffn nn.Sequential( nn.Linear(d_model, hidden), nn.ReLU(), nn.Linear(hidden, d_model), ) def forward(self, robot_xy, task_xy): hr self.robot_enc(robot_xy) # (d,) ht self.task_enc(task_xy) # (N,d) Q self.Wq(hr) # (d,) K self.Wk(ht) # (N,d) V self.Wv(ht) # (N,d) attn_scores (K Q) / math.sqrt(K.shape[-1]) # (N,) a torch.softmax(attn_scores, dim0) # (N,) c a V # (d,) u self.post_ffn(c) # (d,) logits (K u) / math.sqrt(K.shape[-1]) # (N,) probs torch.softmax(logits, dim0) # (N,) return logits, probs # ------------------------- # Helpers / Env # ------------------------- def sample_tasks(n_tasks3, low-10.0, high10.0): xy np.random.uniform(low, high, size(n_tasks, 2)).astype(np.float32) return torch.tensor(xy, devicedevice) def nearest_task_index(robot_xy, task_xy): dists torch.norm(task_xy - robot_xy[None, :], dim1) return torch.argmin(dists).long() # ------------------------- # Init # ------------------------- model CrossAttnChooser(d_model32, hidden64).to(device) # ✅ RL 训练建议更小 lr避免 logits 直接推爆导致 probs[1,0,0] opt torch.optim.Adam(model.parameters(), lr2e-4) robot_xy torch.tensor([0.0, 0.0], devicedevice) total_steps 100000 print_every 1000 save_every 2000 # running baseline (EMA) baseline 0.0 beta 0.02 # logging (EMA) reward_ema 0.0 reward_beta 0.02 # ✅ 探索相关熵正则 温度 entropy_coef 0.01 # 0.005~0.05 可调越大越探索 tau 2.0 # temperature1 更平滑更探索 # ------------------------- # Load checkpoint if exists # ------------------------- start_step 0 if os.path.exists(MODEL_PATH): print(fLoading checkpoint: {MODEL_PATH}) ckpt torch.load(MODEL_PATH, map_locationdevice) model.load_state_dict(ckpt[model]) # ✅ 切 reward / 调 lr 时强烈建议不要 load optimizer动量会把你推向极端 # opt.load_state_dict(ckpt[optimizer]) # start_step int(ckpt.get(step, 0)) baseline float(ckpt.get(baseline, 0.0)) reward_ema float(ckpt.get(reward_ema, 0.0)) else: print(No checkpoint found, training from scratch.) def save_ckpt(step): torch.save({ model: model.state_dict(), optimizer: opt.state_dict(), step: step, baseline: baseline, reward_ema: reward_ema, seed: seed, tau: tau, entropy_coef: entropy_coef, }, MODEL_PATH) print(fCheckpoint saved: step{step} - {MODEL_PATH}) # ------------------------- # Training (REINFORCE baseline entropy) # ------------------------- for step in range(start_step 1, total_steps 1): task_xy sample_tasks(3) logits, _ model(robot_xy, task_xy) # ✅ 用 logits 构造分布数值更稳并用 temperature 拉平 dist torch.distributions.Categorical(logitslogits / tau) action dist.sample() logp dist.log_prob(action) # ✅ reward负欧式距离建议缩放避免 reward/advantage 过大 r -torch.norm(task_xy[action] - robot_xy) / 10.0 r_item float(r.detach().cpu().item()) baseline (1 - beta) * baseline beta * r_item advantage (r - baseline).detach() # ✅ 熵正则鼓励探索防止 probs 早早变成 [1,0,0] entropy dist.entropy() loss -(logp * advantage) - entropy_coef * entropy opt.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # ✅ 防止梯度爆 opt.step() reward_ema (1 - reward_beta) * reward_ema reward_beta * r_item if step % print_every 0: # 打印一次当前 probs从 logits/tau 得到 with torch.no_grad(): probs_dbg torch.softmax(logits / tau, dim0) print(Probs:, probs_dbg.detach().cpu().numpy()) print(faction{action.item()} logp{logp.item():.4f} entropy{entropy.item():.4f}) model.eval() with torch.no_grad(): correct 0 trials 1000 for _ in range(trials): txy sample_tasks(3) y_true nearest_task_index(robot_xy, txy).item() l, _ model(robot_xy, txy) p torch.softmax(l, dim0) # 评估用 tau1 更真实 y_pred torch.argmax(p).item() correct (y_pred y_true) acc correct / trials model.train() print(fstep{step:6d} r_ema{reward_ema:.4f} baseline{baseline:.4f} fadv{advantage.item():.4f} loss{loss.item():.4f} acc{acc*100:.1f}%) if step % save_every 0: save_ckpt(step) save_ckpt(total_steps) # ------------------------- # Test # ------------------------- model.eval() with torch.no_grad(): fixed_tasks torch.tensor([[-8.0, 0.0], [ 2.0, 0.0], [ 7.0, 0.0]], devicedevice) logits, _ model(robot_xy, fixed_tasks) probs torch.softmax(logits, dim0) print(\nFixed tasks:, fixed_tasks.detach().cpu().numpy()) print(Probs:, probs.detach().cpu().numpy()) print(Chosen task index:, torch.argmax(probs).item())

相关新闻

2026年毕业论文降AI总失败?可能是这3个坑你踩了

2026年毕业论文降AI总失败?可能是这3个坑你踩了

改了三遍,AI率还是55%。 我当时真的想砸电脑。明明每段都改过了,检测报告里标红的地方也全部重写了,结果AI率不降反升。 后来才搞明白,毕业论文降AI总失败,不是你不够努力,是方法根本就错了。今年检测系统…

2026/7/5 3:32:50 阅读更多 →
2026年DeepSeek写的论文AI率太高?双引擎降AI工具3分钟搞定

2026年DeepSeek写的论文AI率太高?双引擎降AI工具3分钟搞定

2026年DeepSeek写的论文AI率太高?双引擎降AI工具3分钟搞定 答辩前一周,导师把论文扔回来:“AI率92%,你当我瞎?” 我懵了。明明用DeepSeek写完之后又让它"口语化重写"了一遍,怎么还是这么高&…

2026/7/3 14:31:56 阅读更多 →
3款降AI率工具实测对比:不达标退款的那个效果意外最好

3款降AI率工具实测对比:不达标退款的那个效果意外最好

3款降AI率工具实测对比:不达标退款的那个效果意外最好 用了3款降AI率工具后,我只推荐这2个。 室友花了200块,我花了不到50块,最后我们的AI率都降到了10%以下。差别在哪?选对工具。 先说结论:嘎嘎降AI&…

2026/7/3 14:31:57 阅读更多 →

最新新闻

GBFR-Logs终极指南:从零开始掌握《碧蓝幻想:Relink》伤害统计

GBFR-Logs终极指南:从零开始掌握《碧蓝幻想:Relink》伤害统计

GBFR-Logs终极指南:从零开始掌握《碧蓝幻想:Relink》伤害统计 【免费下载链接】gbfr-logs GBFR Logs lets you track damage statistics with a nice overlay DPS meter for Granblue Fantasy: Relink. 项目地址: https://gitcode.com/gh_mirrors/gb/g…

2026/7/5 3:47:07 阅读更多 →
从团队项目角度看 AI API 聚合平台:别等成本失控后才补日志

从团队项目角度看 AI API 聚合平台:别等成本失控后才补日志

从团队项目角度看 AI API 聚合平台:别等成本失控后才补日志摘要: 很多团队第一次接入模型 API 时,关注点通常是“能不能跑通”。 但项目真正进入多人协作后,更容易出问题的是成本归属、调用日志、限流策略、错误排查和数据边界。 …

2026/7/5 3:45:06 阅读更多 →
目的:这个项目是干什么的?

目的:这个项目是干什么的?

任何一个项目都有他要实现的功能,而操作说明书就是告诉你怎么去用它,怎么去操作这些代码,这些代码提供了一个怎样的服务。如果你进到一个比较正规的公司的 话,会有测试的,有些操作你操作不了,可以求助测试…

2026/7/5 3:45:06 阅读更多 →
中小工厂零部件混采存在哪些供应链优化方式?2026 降本增效采购维度解读

中小工厂零部件混采存在哪些供应链优化方式?2026 降本增效采购维度解读

中小工厂零部件混采降本指南:2026年供应链优化的四个技术维度读者定位:本文专为中小型制造企业主、设备技术负责人及采购工程师而写,旨在解决长期困扰小批量零部件采购中的“价格高、交期长、易被拒单”的核心痛点。解决问题:本文…

2026/7/5 3:43:06 阅读更多 →
体验Managed Extensibility Framework精妙的设计

体验Managed Extensibility Framework精妙的设计

MEF(Managed Extensibility Framework)是.NET Framework 4.0一个重要的库,Visual Studio 2010 Code Editor的扩展支持也是基于MEF构建的。MEF的目标是简化创建可扩展的应用程序,其核心类是ComposablePart,即具有组合能…

2026/7/5 3:41:05 阅读更多 →
IAST实战:基于污点跟踪的Web应用漏洞精准检测与自动化集成

IAST实战:基于污点跟踪的Web应用漏洞精准检测与自动化集成

1. 项目概述:为什么大型Web应用需要IAST?如果你是一名负责大型电商、金融或SaaS平台安全测试的工程师,面对一个由数百个微服务、数千个API接口、大量JavaScript动态渲染页面构成的庞然大物,传统的漏洞扫描工具是不是经常让你感到力…

2026/7/5 3:41:05 阅读更多 →

日新闻

B站视频下载神器BiliTools:5分钟学会轻松保存任何B站内容

B站视频下载神器BiliTools:5分钟学会轻松保存任何B站内容

B站视频下载神器BiliTools:5分钟学会轻松保存任何B站内容 【免费下载链接】BiliTools A cross-platform bilibili toolbox. 跨平台哔哩哔哩工具箱,支持下载视频、番剧等等各类资源 项目地址: https://gitcode.com/GitHub_Trending/bilit/BiliTools …

2026/7/5 0:03:34 阅读更多 →
威胁模型全解析:从新手入门到实战应用,助你构建安全产品!

威胁模型全解析:从新手入门到实战应用,助你构建安全产品!

威胁模型的陌生现状在忙碌疲惫的一天里,参与了关于混合后量子密码学的讨论,应付端点攻击找茬的人,还参与留言板讨论后,发现“威胁模型”对多数人仍是陌生概念,且多被当作时髦用语。有趣的相关画作有一幅由 Embyr 创作的…

2026/7/5 0:03:34 阅读更多 →
渗透测试入门指南:从零基础到实战环境搭建

渗透测试入门指南:从零基础到实战环境搭建

1. 从“看热闹”到“入门”:我理解的渗透测试到底是什么?每次看到新闻里说某个大公司的数据被“黑”了,或者某个网站被攻击导致服务瘫痪,你是不是和我一样,心里会冒出两个念头:一是“这黑客真厉害”&#x…

2026/7/5 0:07:38 阅读更多 →

周新闻

B站视频下载神器BiliTools:5分钟学会轻松保存任何B站内容

B站视频下载神器BiliTools:5分钟学会轻松保存任何B站内容

B站视频下载神器BiliTools:5分钟学会轻松保存任何B站内容 【免费下载链接】BiliTools A cross-platform bilibili toolbox. 跨平台哔哩哔哩工具箱,支持下载视频、番剧等等各类资源 项目地址: https://gitcode.com/GitHub_Trending/bilit/BiliTools …

2026/7/5 0:03:34 阅读更多 →
威胁模型全解析:从新手入门到实战应用,助你构建安全产品!

威胁模型全解析:从新手入门到实战应用,助你构建安全产品!

威胁模型的陌生现状在忙碌疲惫的一天里,参与了关于混合后量子密码学的讨论,应付端点攻击找茬的人,还参与留言板讨论后,发现“威胁模型”对多数人仍是陌生概念,且多被当作时髦用语。有趣的相关画作有一幅由 Embyr 创作的…

2026/7/5 0:03:34 阅读更多 →
渗透测试入门指南:从零基础到实战环境搭建

渗透测试入门指南:从零基础到实战环境搭建

1. 从“看热闹”到“入门”:我理解的渗透测试到底是什么?每次看到新闻里说某个大公司的数据被“黑”了,或者某个网站被攻击导致服务瘫痪,你是不是和我一样,心里会冒出两个念头:一是“这黑客真厉害”&#x…

2026/7/5 0:07:38 阅读更多 →

月新闻