google-deepmind / rlax

https://rlax.readthedocs.io
Apache License 2.0

Stop action gradient in policy gradient loss #109

Open · danijar opened 2 years ago

danijar commented 2 years ago

The current implementation of policy_gradient_loss is:

log_pi_a_t = distributions.softmax().logprob(a_t, logits_t)
adv_t = jax.lax.select(use_stop_gradient, jax.lax.stop_gradient(adv_t), adv_t)
loss_per_timestep = -log_pi_a_t * adv_t

It's good that the gradients are already stopped around the advantages, but they should also be stopped around the actions to ensure an unbiased gradient estimator.
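
For example, the loss could stop the action gradient the same way it already stops the advantage gradient (one possible placement, reusing the existing use_stop_gradient flag):

# Block gradients flowing back through the sampled action before the log-prob.
a_t = jax.lax.select(use_stop_gradient, jax.lax.stop_gradient(a_t), a_t)
log_pi_a_t = distributions.softmax().logprob(a_t, logits_t)
adv_t = jax.lax.select(use_stop_gradient, jax.lax.stop_gradient(adv_t), adv_t)
loss_per_timestep = -log_pi_a_t * adv_t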

This is important when the actions are sampled as part of the training graph (MPO-style algorithms, imagination training with world models) rather than coming from the replay buffer, and the actor distribution implements a gradient for sample() (e.g. a reparameterized Gaussian, or straight-through categoricals).
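
As a toy illustration (not rlax code): with a reparameterised Gaussian sampled inside the training graph, the log-prob term keeps a gradient path through a_t back to the policy parameters unless the action is stopped:

import jax
import jax.numpy as jnp

def gaussian_logprob(x, mean, std):
  # Elementwise log-density of a diagonal Gaussian.
  return -0.5 * jnp.square((x - mean) / std) - jnp.log(std) - 0.5 * jnp.log(2.0 * jnp.pi)

def pg_loss(mean, std, adv, key):
  noise = jax.random.normal(key, mean.shape)
  a_t = mean + std * noise              # reparameterised sample, differentiable w.r.t. mean/std
  a_t = jax.lax.stop_gradient(a_t)      # without this line, extra (biasing) gradients flow through the action
  adv = jax.lax.stop_gradient(adv)
  return -jnp.sum(gaussian_logprob(a_t, mean, std) * adv)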

alirezakazemipour commented 1 year ago

Hi @danijar, this might be a bit tangential to the issue, but could you please specify what you mean by "MPO-style algorithms"?

I understand your point about preventing gradients from flowing through the actions in the computation graph; I just didn't understand what you mean by an MPO-style algorithm. Thank you in advance. 🙏

danijar commented 1 year ago

Algorithms that sample multiple actions per replayed transition and perform loss-weighted regression on them based on some performance score computed from the critic.
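
Roughly, a simplified sketch of that pattern (assumed names and shapes, not an rlax API):

import jax
import jax.numpy as jnp

def mpo_style_policy_loss(logits, sampled_actions, q_values, temperature=1.0):
  # logits: [B, A] online policy logits.
  # sampled_actions: [B, K] integer actions sampled per replayed state (e.g. from a target policy).
  # q_values: [B, K] critic scores for those actions (forward-pass information only).
  weights = jax.lax.stop_gradient(jax.nn.softmax(q_values / temperature, axis=-1))
  log_probs = jnp.take_along_axis(
      jax.nn.log_softmax(logits, axis=-1), sampled_actions, axis=-1)
  # Loss-weighted regression of the policy onto its own (critic-scored) samples.
  return -jnp.mean(jnp.sum(weights * log_probs, axis=-1))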

alirezakazemipour commented 1 year ago

Thank you @danijar! If I'm not mistaken, MPO-style algorithms are then more prevalent in the context of model-based RL, especially at the planning stage. Nonetheless, as you have also underlined, I believe the issue you have raised points to a broader class of methods based on policy gradients (say, REINFORCE): the gradients must be computed w.r.t. the current iteration's parameters, and if the actions carry gradients from previous iterations, those gradients accumulate and make the optimization invalid. Though I'm not sure I have described correctly how the final gradient estimate would be biased; I mean, biased towards what? Thank you again! :pray:

danijar commented 1 year ago

MPO and V-MPO are model-free algorithms; you can think of them as an extension of DDPG to stochastic policies, where you only use forward-pass information from the critic and not its gradients. I don't know what it's biased towards, but the implementation will not estimate the correct gradient (perhaps the rough direction of the gradient is still good, because it did train in my case, just to worse performance).

alirezakazemipour commented 1 year ago

@danijar Yeah, I googled them and found the corresponding papers! Thank you so much for your explanations, and I apologize if I chimed in with an off-topic question here. ❤️

hbq1 commented 1 year ago

@danijar thanks for reporting this and apologies for the delayed response.

I guess we may want gradients to flow through the actions in some cases. We could add an optional argument, but it's probably better if stop_gradient is placed on the arguments of policy_gradient_loss() in the calling code, WDYT?
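
For example (assuming the usual rlax.policy_gradient_loss(logits_t, a_t, adv_t, w_t) signature), the calling code could look roughly like:

import jax
import rlax

def loss_fn(logits_t, a_t, adv_t, w_t):
  # Stop gradients on the sampled actions (and advantages) at the call site
  # instead of inside the loss itself.
  a_t = jax.lax.stop_gradient(a_t)
  adv_t = jax.lax.stop_gradient(adv_t)
  return rlax.policy_gradient_loss(logits_t, a_t, adv_t, w_t)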