Closed acezsq closed 6 months ago
你的理解是没错的,应该是我代码出现了问题。logp_old_con最后应该是由act_dis index 的,至于这一步在哪里处理没有太大所谓。原本那个文件是两年前的,我已经把正确的更新上去了(我经常更新的在一个线下仓库,改着改着这边忘记同步了,非常抱歉,隐隐感觉是个雷,这个仓库做个参考就好,有问题可以问我,我毕业设计搞完应该会把完整的放上来)return state_value, action_dis, action_con[action_dis], logprob_dis, logprob_con[action_dis]。
感谢回复,关于HPPO的实现还有个额外的问题想请教一下,就是比如我的环境是三个离散动作,两个连续动作,比方说下面这个链接里的环境https://di-engine-docs.readthedocs.io/zh-cn/latest/13_envs/gym_hybrid.html 这种环境下logp应该如何计算,如果离散动作选择到没有对应连续动作那个,这个时候好像没办法计算logp了,hppo算法针对这种应该怎么处理? 感谢您能百忙之中回复!
根据我的理解,discrete_actor三个heads;continuous_actor 两个heads,选到turn或者是accelerate就用logp_con = logp_con.gather(1, act_dis.view(-1, 1)).squeeze()
去更新相应的连续网络就行,选到break,那么连续网络就什么都不做
根据我的理解,discrete_actor三个heads;continuous_actor 两个heads,选到turn或者是accelerate就用
logp_con = logp_con.gather(1, act_dis.view(-1, 1)).squeeze()
去更新相应的连续网络就行,选到break,那么连续网络就什么都不做
感谢回复!我准备在更新时离散网络用一个batch的数据更新,连续网络更新时把batch中离散动作为2的数据剔除,然后用剔除后batch中剩下的数据更新连续网络。
我觉得最核心的一行代码是:
logp_con = logp_con.gather(1, act_dis.view(-1, 1)).squeeze()
。关于这行代码我的理解是比如三个离散动作,每个离散动作对应一个连续的动作,这行代码是通过logp_con
找到对应的离散动作的logp
,比如logp_con
是(0.52,0.78,0.96),act_dis
是1,所以选择0.78作为最终的logp_con
。 这也就跟在class PPOBuffer:
中定义self.logp_con_buf = np.zeros(size, dtype=np.float32)
的shape
对应,然而在ratio_con = torch.exp(logp_con - logp_old_con)
中的logp_old_con
我没有在您的代码中找到类似的对logp_old_con
进行gather操作进而降低维度的操作,也就是说logp_old_con
的是(0.52,0.78,0.96)而不是一个数,这两者的shape
也就不一致,没法进行计算。 在class ActorCritic_Hybrid(nn.Module):
类的实现中def act(self, state):
的定义中也没有看到根据离散动作找到对应连续动作求logp_con
的操作。 希望得到解答。