Description
I ported the JAX implementation of the SAC algorithm to PyTorch. In the reparameterization step, where the loss is computed from samples drawn from a uniform distribution, I found that the two implementations produce the same loss value given identical inputs and network weights, but different gradients. Why is this?
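Not part of the original report, but one common cause of "same loss value, different (or missing) gradients" when porting reparameterized sampling from JAX to PyTorch is using `torch.distributions.Distribution.sample()` instead of `rsample()`: only `rsample()` applies the reparameterization trick, so only it lets gradients flow back into the distribution parameters. A minimal sketch (the parameters `mu` and `std` here are hypothetical stand-ins for a policy head's outputs):

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

# Hypothetical 1-D policy parameters (stand-ins for a network's output head).
mu = torch.tensor([0.5], requires_grad=True)
std = torch.tensor([0.8], requires_grad=True)
dist = Normal(mu, std)

# rsample() applies the reparameterization trick (z = mu + std * eps),
# so the gradient of the loss flows back into mu and std.
z = dist.rsample()
loss = (z ** 2).mean()
loss.backward()
assert mu.grad is not None and std.grad is not None

# sample() draws under torch.no_grad(): the resulting tensor carries no
# autograd history, so a loss built from it has the same numeric form but
# propagates no gradient to mu and std.
z2 = dist.sample()
assert z2.requires_grad is False
```

If both ports already use the reparameterized path, another thing worth checking is any non-differentiable op near the sample (e.g. clipping or `tanh` saturation at the boundary), where JAX and PyTorch may define the subgradient differently even though the forward values agree.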
System info (python version, jaxlib version, accelerator, etc.)
jax==0.4.28, flax==0.8.0, torch==2.3.1