Khrylx / PyTorch-RL

PyTorch implementation of Deep Reinforcement Learning: Policy Gradient methods (TRPO, PPO, A2C) and Generative Adversarial Imitation Learning (GAIL). Fast Fisher vector product TRPO.
MIT License

TRPO: KL Divergence Computation #11

sandeepnRES closed this issue 5 years ago.

sandeepnRES commented 5 years ago

I see how KL divergence is computed here:

```python
def get_kl(self, x):
    action_prob1 = self.forward(x)
    action_prob0 = action_prob1.detach()
    kl = action_prob0 * (torch.log(action_prob0) - torch.log(action_prob1))
    return kl.sum(1, keepdim=True)
```

Isn't this wrong? Shouldn't the KL divergence be computed between the new policy and the old policy? Right now `action_prob1` and `action_prob0` appear to be the same, so the KL divergence will always be zero, won't it?
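The expression in `get_kl` is the discrete KL divergence D_KL(π_old ‖ π_new) = Σ_a π_old(a) (log π_old(a) - log π_new(a)). Here `action_prob0` plays π_old and `action_prob1` plays π_new.

The sketch below applies the same formula to plain Python lists holding one state's action probabilities. The helper name `categorical_kl` and the example probabilities are hypothetical and do not come from the repo. The result is zero when the two policies are identical and positive once they differ:

```python
import math

def categorical_kl(p_old, p_new):
    # Same per-state formula as get_kl: sum_a p_old[a] * (log p_old[a] - log p_new[a])
    return sum(po * (math.log(po) - math.log(pn)) for po, pn in zip(p_old, p_new))

p = [0.2, 0.5, 0.3]   # hypothetical "old" action probabilities
q = [0.3, 0.4, 0.3]   # hypothetical "new" action probabilities after some update

print(categorical_kl(p, p))  # 0.0: identical policies, as at the start of an update
print(categorical_kl(p, q))  # positive once the new policy has moved away from the old one
```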

Khrylx commented 5 years ago

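I'm not sure what the problem is. Before the update, the new policy is equal to the old policy, so the KL is zero. In fact, the first derivative of the KL is also zero, because the KL reaches its minimum when the new policy equals the old policy. What we care about is the second derivative (the Hessian) of the KL, which is not zero. Because `action_prob0` is detached, it is treated as a constant. Differentiating `get_kl` twice with respect to the policy parameters therefore gives that Hessian.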
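The sketch below illustrates why a KL that evaluates to zero is still useful. It takes a single state with a softmax policy over three actions. It freezes a copy of the current probabilities, which is the role `detach()` plays for `action_prob0`. It then estimates the value, gradient, and Hessian of KL(π_old ‖ π_θ) at θ = θ_old using central finite differences.

This is an illustration under stated assumptions, not the repository's code:
- It uses plain Python instead of PyTorch autograd.
- The logits, the step size `h`, and the helper names (`softmax`, `kl`, `f`, `bump`) are illustrative choices.

The value and gradient come out as zero. The Hessian matches diag(π) - ππᵀ, which is the Fisher information of a categorical distribution with respect to its logits.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl(p_old, p_new):
    # Same formula as get_kl: sum_a p_old[a] * (log p_old[a] - log p_new[a])
    return sum(po * (math.log(po) - math.log(pn)) for po, pn in zip(p_old, p_new))

theta_old = [0.1, -0.4, 0.7]   # hypothetical logits of the current policy for one state
p_old = softmax(theta_old)     # frozen copy, like action_prob0 = action_prob1.detach()

def f(theta):
    # Only the second argument depends on theta, like action_prob1
    return kl(p_old, softmax(theta))

def bump(theta, i, di, j, dj):
    # Return a copy of theta with di added at index i and dj added at index j
    out = list(theta)
    out[i] += di
    out[j] += dj
    return out

h = 1e-3
n = len(theta_old)

value = f(theta_old)
# Central differences for the gradient
grad = [(f(bump(theta_old, i, h, i, 0.0)) - f(bump(theta_old, i, -h, i, 0.0))) / (2 * h)
        for i in range(n)]
# Central differences for the Hessian (i == j reduces to a second difference with step 2h)
hess = [[(f(bump(theta_old, i, h, j, h)) - f(bump(theta_old, i, h, j, -h))
          - f(bump(theta_old, i, -h, j, h)) + f(bump(theta_old, i, -h, j, -h))) / (4 * h * h)
         for j in range(n)] for i in range(n)]
# Fisher information of a categorical distribution w.r.t. its logits: diag(p) - p p^T
fisher = [[(p_old[i] if i == j else 0.0) - p_old[i] * p_old[j] for j in range(n)]
          for i in range(n)]

print("KL value:", value)                              # exactly 0.0
print("max |gradient|:", max(abs(g) for g in grad))    # ~0, up to finite-difference error
print("max |Hessian - Fisher|:",
      max(abs(hess[i][j] - fisher[i][j]) for i in range(n) for j in range(n)))
```

Analytically, KL(π_old ‖ softmax(θ)) = Σ_a π_old(a) log π_old(a) - Σ_a π_old(a) θ_a + log Σ_a exp(θ_a). Its gradient is softmax(θ) - π_old, which vanishes at θ_old. Its Hessian is the Hessian of log-sum-exp, which equals diag(π) - ππᵀ at θ_old. This second-order information is what survives even though the KL value itself is zero.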