google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Alpha Update in Soft-Actor-Critic #139

Closed · milutter closed this issue 2 years ago

milutter commented 2 years ago

Hey Brax Community, I am not an expert on SAC, but the alpha update within the SAC implementation looks a bit weird to me. I might also be missing some mathematical identity or JAX trick that resolves the misunderstanding.

In line #498, the `alpha_params` are defined as `log_alpha`. In line #343, the gradient of the `alpha_loss` is computed with respect to `alpha`. The resulting gradient is then used in lines 362 and 364 to update `alpha_params` with the Adam optimizer. If I am understanding the code correctly, the `log_alpha` parameters are updated with the gradient taken with respect to `alpha`, not with respect to `log_alpha`. Or am I missing something?
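To illustrate the distinction, here is a minimal JAX sketch; the names and values are placeholders for illustration, not the actual Brax training code:

```python
import jax
import jax.numpy as jnp

# Placeholder values, for illustration only (not Brax's actual code).
target_entropy = -1.0   # desired policy entropy
mean_log_prob = -0.3    # stand-in for E[log pi(a|s)] over a batch

def alpha_loss_from_log_alpha(log_alpha):
  # Loss written in terms of the parameter that is actually optimized,
  # so jax.grad yields the gradient that Adam should apply to log_alpha.
  alpha = jnp.exp(log_alpha)
  return alpha * (-mean_log_prob - target_entropy)

def alpha_loss_from_alpha(alpha):
  # The same loss, but treated as a function of alpha itself.
  return alpha * (-mean_log_prob - target_entropy)

log_alpha = jnp.asarray(0.5)

g_log = jax.grad(alpha_loss_from_log_alpha)(log_alpha)       # alpha * c
g_lin = jax.grad(alpha_loss_from_alpha)(jnp.exp(log_alpha))  # c

# The two gradients differ by the chain-rule factor
# d(alpha)/d(log_alpha) = alpha, so applying g_lin to log_alpha
# mis-scales the update.
print(g_log, g_lin)  # ~2.14 vs 1.3
```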

If this is actually the case, it would be very easy to fix. Looking at different open-source implementations, some actually implement `J_alpha = log_alpha * (expected entropy - desired entropy)` (e.g., Mushroom RL and MBPO), while others implement the original Eq. 18 from the paper, i.e., `J_alpha = alpha * (expected entropy - desired entropy)`.
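For concreteness, a small sketch of the two variants (with a made-up `entropy_gap` standing in for the batch estimate of expected minus desired entropy); under gradient descent on `log_alpha`, both move `alpha` in the same direction, differing only by a factor of `alpha` in the gradient:

```python
import jax
import jax.numpy as jnp

# Made-up stand-in for (expected entropy - desired entropy) from a batch.
entropy_gap = 0.2

def alpha_objective_paper(log_alpha):
  # Eq. 18 style: J_alpha = alpha * (expected entropy - desired entropy).
  return jnp.exp(log_alpha) * entropy_gap

def alpha_objective_log(log_alpha):
  # Mushroom RL / MBPO style: J_alpha = log_alpha * (same bracket).
  return log_alpha * entropy_gap

log_alpha = jnp.asarray(0.5)
print(jax.grad(alpha_objective_paper)(log_alpha))  # alpha * entropy_gap (~0.33)
print(jax.grad(alpha_objective_log)(log_alpha))    # entropy_gap (0.2)
```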

m-orsini commented 2 years ago

Excellent point! We're pushing a change; it should be in later today.

Thanks a lot!

erikfrey commented 2 years ago

This is pushed now - thanks Michael.

milutter commented 2 years ago

Great! Thanks for the fast fix.