Hey Brax Community,
I am not an expert on SAC, but the alpha update within the SAC implementation looks a bit weird to me. I might also be missing some mathematical identity or jax trick that resolves the misunderstanding.
In line #498, the alpha_params are defined as log_alpha. In line #343, the gradient of the alpha_loss is computed with respect to alpha. The resulting gradient is then used in lines 362 & 364 to update the alpha_params with the Adam optimizer. If I am understanding the code correctly, the log_alpha parameters are therefore updated with the gradient with respect to alpha and not with respect to log_alpha. Or am I missing something?
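To make the concern concrete: since alpha = exp(log_alpha), the chain rule gives dJ/d(log_alpha) = alpha * dJ/d(alpha), so applying the gradient with respect to alpha directly to log_alpha would drop a factor of alpha. Here is a minimal JAX check of that relation; all names and values below are made up for illustration and are not taken from the Brax code:

```python
import jax
import jax.numpy as jnp

# Made-up values purely for illustration:
entropy_diff = jnp.array(0.5)   # stand-in for [Expected Entropy - Desired Entropy]
log_alpha = jnp.array(0.3)

loss_in_alpha = lambda alpha: alpha * entropy_diff         # J as a function of alpha
loss_in_log_alpha = lambda la: jnp.exp(la) * entropy_diff  # same J, as a function of log_alpha

g_alpha = jax.grad(loss_in_alpha)(jnp.exp(log_alpha))      # dJ/d(alpha) = entropy_diff
g_log_alpha = jax.grad(loss_in_log_alpha)(log_alpha)       # dJ/d(log_alpha) = alpha * entropy_diff

print(g_alpha, g_log_alpha, jnp.exp(log_alpha) * g_alpha)  # the last two agree
```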
If this is actually the case, it would be very easy to fix. Looking at different open-source implementations, some implement J_alpha = log_alpha * [Expected Entropy - Desired Entropy] (e.g., Mushroom RL and MBPO), while others implement the original Eq. 18 from the paper, i.e., J_alpha = alpha * [Expected Entropy - Desired Entropy].
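As a sketch of how the two variants look when log_alpha is the trainable parameter (again, the names and values here are assumptions for illustration, not the actual Brax code):

```python
import jax
import jax.numpy as jnp

# Illustrative values (not from the Brax code):
desired_entropy = -1.0
expected_entropy = -0.5          # stand-in for the batch mean of -log_pi

def alpha_loss_log_space(log_alpha):
    # Variant used by e.g. Mushroom RL / MBPO: the loss is linear in log_alpha.
    return log_alpha * (expected_entropy - desired_entropy)

def alpha_loss_eq18(log_alpha):
    # Original Eq. 18: J_alpha = alpha * (...); writing alpha = exp(log_alpha)
    # lets autodiff w.r.t. log_alpha pick up the extra factor of alpha itself.
    return jnp.exp(log_alpha) * (expected_entropy - desired_entropy)

log_alpha = jnp.array(0.0)
print(jax.grad(alpha_loss_log_space)(log_alpha))   # 0.5
print(jax.grad(alpha_loss_eq18)(log_alpha))        # 0.5 * exp(0.0) = 0.5
```

Both variants give a well-defined gradient with respect to log_alpha as long as the whole expression is differentiated with respect to log_alpha; the mismatch I am asking about would only arise if a gradient computed with respect to alpha is applied to log_alpha.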