google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Gradient over a full trajectory vs. sum of gradients over 1-step transitions #442

Closed: yecohn closed this issue 4 months ago

yecohn commented 5 months ago

Hi, I have been experimenting with Brax and I am currently interested in taking the derivative of a reward function with respect to the friction parameter. The reward function is given by:

(state_final - state_initial) / dt * <some_constant>

Interestingly, I found that computing the gradient of each 1-step transition and summing those gradients over the trajectory gives a different result from rolling out the full trajectory and then taking the gradient.
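Written out with explicit notation, the two computations can be compared directly. This is an assumed reading of the setup: $s_t$ is the state after $t$ steps, $s_{t+1} = f(s_t, \mu)$ is one simulator step with friction $\mu$, $T$ is the number of steps, $c$ is `<some_constant>`, the initial state $s_0$ does not depend on $\mu$, and "gradient of a 1-step transition" means differentiating with the incoming state $s_t$ held fixed. The reward value telescopes into a sum of per-step terms, but its derivative does not, because every intermediate state depends on $\mu$:

```latex
\begin{aligned}
R(\mu) &= \frac{c}{\Delta t}\,\bigl(s_T - s_0\bigr)
        = \sum_{t=0}^{T-1} \frac{c}{\Delta t}\,\bigl(s_{t+1} - s_t\bigr),\\[4pt]
\frac{dR}{d\mu} &= \frac{c}{\Delta t}\,\frac{ds_T}{d\mu},
\qquad
\frac{ds_{t+1}}{d\mu} = J_t\,\frac{ds_t}{d\mu} + \frac{\partial f}{\partial \mu}(s_t, \mu),
\qquad
J_t = \frac{\partial f}{\partial s}(s_t, \mu),\\[4pt]
\text{sum of 1-step gradients} &= \frac{c}{\Delta t}\sum_{t=0}^{T-1} \frac{\partial f}{\partial \mu}(s_t, \mu).
\end{aligned}
```

The sum of 1-step gradients drops every $J_t\,ds_t/d\mu$ term, so the two results coincide only in special cases, for example $T = 1$, where $ds_0/d\mu = 0$.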

Does anybody know how JAX computes gradients in this case?

Thanks!

btaba commented 4 months ago

Hi @yecohn, see the JAX autodiff documentation for more on how gradients are computed. Why do you expect the gradient over the full trajectory to be the sum of the gradients of each step transition? The step function is applied recursively, so you need to apply the chain rule to get the gradient over the full trajectory.
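A minimal, self-contained JAX sketch of the chain-rule point. It does not use Brax: `step` below is a hypothetical 1-D point mass with velocity-proportional friction `mu`, and `DT`, `C`, and `N_STEPS` are illustrative values standing in for `dt`, `<some_constant>`, and the trajectory length. It compares (a) `jax.grad` of the full rollout, (b) the sum of per-step gradients with each incoming state held fixed, and (c) the full gradient rebuilt by hand with the chain rule.

```python
import jax
import jax.numpy as jnp

DT = 0.1       # hypothetical time step (stands in for dt)
C = 2.0        # hypothetical constant (stands in for <some_constant>)
N_STEPS = 20   # hypothetical trajectory length

def step(state, mu):
    """Hypothetical 1-D point mass; friction mu damps the velocity each step."""
    x, v = state
    v_next = v - mu * v * DT
    x_next = x + v_next * DT
    return jnp.array([x_next, v_next])

def rollout(mu, s0):
    """Apply `step` N_STEPS times with lax.scan; returns states s_1..s_T."""
    def body(s, _):
        s_next = step(s, mu)
        return s_next, s_next
    _, states = jax.lax.scan(body, s0, None, length=N_STEPS)
    return states

def trajectory_reward(mu, s0):
    """(x_T - x_0) / DT * C, differentiated through the whole rollout."""
    s_final = rollout(mu, s0)[-1]
    return (s_final[0] - s0[0]) / DT * C

def step_reward(mu, s):
    """1-step reward; the incoming state s is treated as a constant input."""
    return (step(s, mu)[0] - s[0]) / DT * C

mu = 0.5
s0 = jnp.array([0.0, 1.0])  # x_0 = 0, v_0 = 1

# (a) Gradient of the full trajectory: JAX backpropagates through every step.
g_full = jax.grad(trajectory_reward)(mu, s0)

# (b) Sum of 1-step gradients, each taken at a fixed state s_t (t = 0..T-1).
states = jnp.concatenate([s0[None], rollout(mu, s0)[:-1]])
per_step = jax.vmap(jax.grad(step_reward), in_axes=(None, 0))(mu, states)
g_sum = jnp.sum(per_step)

# (c) Full gradient rebuilt with the chain rule:
#     ds_{t+1}/dmu = J_t @ ds_t/dmu + df/dmu(s_t, mu), with ds_0/dmu = 0.
ds = jnp.zeros(2)
for s in states:
    J = jax.jacfwd(step, argnums=0)(s, mu)    # J_t = df/ds at s_t, shape (2, 2)
    g_t = jax.jacfwd(step, argnums=1)(s, mu)  # df/dmu at s_t, shape (2,)
    ds = J @ ds + g_t
g_chain = ds[0] / DT * C  # only the x component enters the reward

print(g_full, g_sum, g_chain)
```

Here (a) and (c) should agree up to floating-point error, while (b) differs for any rollout longer than one step, because it drops the `J @ ds` term that carries the effect of `mu` on earlier states forward through the trajectory.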