ikostrikov / pytorch-a2c-ppo-acktr-gail

PyTorch implementation of Advantage Actor Critic (A2C), Proximal Policy Optimization (PPO), Scalable trust-region method for deep reinforcement learning using Kronecker-factored approximation (ACKTR) and Generative Adversarial Imitation Learning (GAIL).

Possible bug on the sign of policy log prob. in Fisher computation #279

Open daniloefl opened 3 years ago

daniloefl commented 3 years ago

Dear @ikostrikov ,

while reading your code I noticed that the Fisher matrix calculation uses the log probability of a normal distribution for the value term, but the negative log probability of the policy for the policy term. Comparing your code in [1] with the equivalent lines of stable-baselines [2], one can see that there the policy part of the Fisher matrix calculation is a log probability (minus the negative log probability, in pg_fisher_loss) and the value-function contribution is also a log probability (minus a mean squared error).
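
To make the comparison concrete, here is a minimal sketch of the two sign conventions (my own paraphrase with placeholder tensors, not the exact code of either repository):

```python
import torch

# Placeholders standing in for the actor-critic forward pass.
action_log_probs = torch.randn(8, 1, requires_grad=True)    # log pi(a|s)
values = torch.randn(8, 1, requires_grad=True)               # V(s)
sample_values = values.detach() + torch.randn_like(values)   # v ~ N(V(s), 1)

# Convention in [1] (paraphrased): the policy term is the NEGATIVE log prob,
# while the value term is the negative squared error, i.e. a log prob up to constants.
pg_fisher_loss = -action_log_probs.mean()
vf_fisher_loss = -(values - sample_values).pow(2).mean()
fisher_loss = pg_fisher_loss + vf_fisher_loss

# Convention in [2] (paraphrased): both terms are log probs
# (minus the negative log prob, and minus the squared error), so the signs match.
pg_fisher_loss_sb = action_log_probs.mean()
vf_fisher_loss_sb = -(values - sample_values).pow(2).mean()
fisher_loss_sb = pg_fisher_loss_sb + vf_fisher_loss_sb
```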

The original paper constructs the Fisher matrix from the gradient of the log probability of the policy and the log probability of a Gaussian around the value-function output (section 3.1 of [3]). I would therefore expect the signs of the two terms entering the Fisher matrix to follow the same convention, as is done in the stable-baselines repository. The actual loss minimisation uses the negative log probability for both terms (as you currently do, and as stable-baselines does), but in either case the signs of the two terms should be consistent with each other.
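
Paraphrasing section 3.1 of [3] (the notation below is mine, not the paper's), the output distribution factorises into a policy part and a Gaussian around the value output, so both contributions to the Fisher matrix come from gradients of log probabilities with the same sign:

```latex
% Paraphrase of the Fisher construction in section 3.1 of [3]; notation is mine.
p(a, v \mid s) = \pi(a \mid s)\,\mathcal{N}\!\bigl(v;\, V(s), \sigma^{2}\bigr),
\qquad
\log p(a, v \mid s) = \log \pi(a \mid s) - \frac{\bigl(v - V(s)\bigr)^{2}}{2\sigma^{2}} + \mathrm{const},
\qquad
F = \mathbb{E}\Bigl[\nabla_{\theta} \log p(a, v \mid s)\,
                    \nabla_{\theta} \log p(a, v \mid s)^{\top}\Bigr].
```

Since the value log probability is, up to constants, minus a squared error, both the policy and value terms carry the log-probability sign under this construction.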

Therefore, I could not fully understand the reason for that sign in the Fisher matrix calculation. Is this a bug, or is there a deeper reason behind it?

Best regards, Danilo

[1] https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/master/a2c_ppo_acktr/algo/a2c_acktr.py#L53
[2] https://stable-baselines.readthedocs.io/en/master/_modules/stable_baselines/acktr/acktr.html#ACKTR
[3] https://arxiv.org/pdf/1708.05144.pdf