datawhalechina / easy-rl

强化学习中文教程(蘑菇书🍄),在线阅读地址:https://datawhalechina.github.io/easy-rl/
Other
9.04k stars 1.81k forks source link

DuelingDQN.ipynb中可能存在的两个BUG~ #140

Open libermeng opened 1 year ago

libermeng commented 1 year ago
  1. 定义模型部分forward函数中return value + advantage - advantage.mean()可能有误,应该改为return value + advantage - advantage.mean(dim=1, keepdim=True)。 因为按照定义,优势网络输出的值要满足的条件应该是保持在动作维度上的和为0,那么减去的均值应该只是动作维度的均值,而不是总体的均值。
  2. 定义算法部分初始化函数中self.policy_net = model.to(self.device)self.target_net = model.to(self.device)有误,应该改成 self.policy_net = DuelingNet(cfg.n_states, cfg.n_actions, hidden_dim=cfg.hidden_dim).to(self.device)self.target_net = DuelingNet(cfg.n_states, cfg.n_actions, hidden_dim=cfg.hidden_dim).to(self.device)。 因为原初始化方式是初始化了两个相同内存地址的policy_net和target_net对象,修改后的初始化方式才是初始化两个不同内存地址的对象。
severus98 commented 2 weeks ago

附议,DDPG的初始化也存在这个问题:

self.device = torch.device(cfg['device']) self.critic = models['critic'].to(self.device) self.target_critic = models['critic'].to(self.device) self.actor = models['actor'].to(self.device) self.target_actor = models['actor'].to(self.device)

for target_param, param in zip(self.target_critic.parameters(), self.critic.parameters()): targetparam.data.copy(param.data) for target_param, param in zip(self.target_actor.parameters(), self.actor.parameters()): targetparam.data.copy(param.data)

这里的actor和target_actor,critic和target_critic实际上是对同一个网络模型的引用,并非独立的网络。

有两种改正方式:

  1. 一种如贴主所言,在初始化阶段传入类名Actor和Critic而非直接传入实例,然后分别赋参数,各实例化为独立的两个网络模型实例,然后再执行param.data.copy_参数拷贝;

  2. 另一种,可以如https://github.com/sfujim/TD3/blob/master/DDPG.py,在实例化一个网络之后,采用copy.deepcopy进行深拷贝到对应的target网络,从而创建两个独立且初始参数相同的网络模型

severus98 commented 2 weeks ago

补充一下我尝试DDPG的实验结果: 分别使用: 1)蘑菇书配套原版代码
2)蘑菇书代码基础上,使用上述的copy.deepcopy复制target网络
3) https://github.com/sfujim/TD3/blob/master/DDPG.py 训练得到网络。 然后使用同一环境种子,测试对比100回合: 1)

load_WRONG_Network_test_seed_plus100

2)

load_deepcopy_test_seed_plus100

3)

load_github_test_seed_plus100

2)和 3) 大多数时候reward在-200 以上,性能接近,而 1)大多数时候在-200以下,性能差距较大;3)比2)略好,可能与训练方式、超参数有关。

考虑到方法3)是TD3算法作者实现的源码,有一定参考性和准确性,可以说明2)这种改正网络copy的方式是有效的。