google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Is there a difference in computing gradients between JAX and PyTorch? #22101

Open ergo-zyh opened 1 week ago

ergo-zyh commented 1 week ago

Description

I ported a JAX implementation of the SAC algorithm to PyTorch. In the reparameterization step, where the loss is computed by sampling from a uniform distribution, I found that with identical inputs and network weights the loss values match, but the gradients differ. Why is this? A minimal comparison of the two is sketched below.
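For reference, here is a minimal sketch (not the poster's actual code) of how such a comparison could be set up: the weights, inputs, and noise sample are generated once in NumPy and shared between the two frameworks, so that only the autodiff machinery differs. The toy loss `jax_loss` and all array shapes are illustrative assumptions.

```python
# Minimal gradient-comparison sketch: same weights, input, and noise sample
# fed to both JAX and PyTorch, so the forward passes match exactly.
import jax
import jax.numpy as jnp
import numpy as np
import torch

rng = np.random.default_rng(0)
w_np = rng.normal(size=(3,)).astype(np.float32)     # shared "network weight"
x_np = rng.normal(size=(3,)).astype(np.float32)     # shared input
eps_np = rng.uniform(size=(3,)).astype(np.float32)  # shared noise sample

# JAX: toy reparameterized loss, gradient w.r.t. w via jax.grad
def jax_loss(w, x, eps):
    return jnp.mean((w * x + eps) ** 2)

jax_grad = jax.grad(jax_loss)(
    jnp.asarray(w_np), jnp.asarray(x_np), jnp.asarray(eps_np)
)

# PyTorch: the same loss, gradient via autograd
w_t = torch.tensor(w_np, requires_grad=True)
loss_t = torch.mean((w_t * torch.tensor(x_np) + torch.tensor(eps_np)) ** 2)
loss_t.backward()

print("JAX grad:  ", np.asarray(jax_grad))
print("torch grad:", w_t.grad.numpy())  # should agree up to float32 tolerance
```

With the noise sample fixed externally like this, the gradients should agree up to floating-point tolerance; if each framework draws its own sample from its own RNG, the losses and gradients will generally differ even with identical weights.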

System info (python version, jaxlib version, accelerator, etc.)

jax==0.4.28, flax==0.8.0, torch==2.3.1

ayaka14732 commented 1 week ago

This sounds strange. Do you have a code example?