Closed yecohn closed 4 months ago
Hi @yecohn, see more info on JAX autodiff for computing gradients. Why do you expect the gradient over the full trajectory to be the sum of the gradients for each step transition? The step
function is applied recursively, so each state depends on the friction parameter both directly and through every earlier state. You need to apply the chain rule through all of the steps to get the gradient over the full trajectory.
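To make the difference concrete, here is a minimal sketch in JAX. It does not use brax: `step`, `reward`, `DT`, and `N_STEPS` are hypothetical stand-ins (a scalar velocity damped by a friction coefficient `mu`), chosen only so the two gradients can be compared by hand. The first function differentiates through the whole rollout; the second sums one-step gradients while treating each incoming state as a constant, which is one plausible reading of the per-step computation in the question.

```python
import jax
import jax.numpy as jnp

DT = 0.1      # hypothetical time step
N_STEPS = 5   # hypothetical trajectory length


def step(v, mu):
    # Toy dynamics (an assumption, not brax): friction damps the velocity.
    return v * (1.0 - mu * DT)


def reward(v):
    # Toy reward (an assumption): penalize the squared velocity.
    return -v ** 2


def trajectory_reward(mu, v0):
    # Roll out the full trajectory; every state depends on mu through
    # all previous applications of `step`.
    def body(v, _):
        v_next = step(v, mu)
        return v_next, reward(v_next)

    _, rewards = jax.lax.scan(body, v0, None, length=N_STEPS)
    return jnp.sum(rewards)


def summed_one_step_grads(mu, v0):
    # Gradient of a single transition w.r.t. mu, with the incoming
    # state v passed in as a plain value (so it carries no dependence on mu).
    one_step_grad = jax.grad(lambda m, v: reward(step(v, m)))
    total = jnp.zeros(())
    v = v0
    for _ in range(N_STEPS):
        total = total + one_step_grad(mu, v)
        v = step(v, mu)
    return total


mu = jnp.array(0.5)
v0 = jnp.array(1.0)

full_grad = jax.grad(trajectory_reward)(mu, v0)   # chain rule through all steps
step_grad_sum = summed_one_step_grads(mu, v0)     # ignores dv_t/dmu

print("full trajectory gradient:", full_grad)
print("sum of one-step gradients:", step_grad_sum)
```

For this toy system, with a = 1 - mu * DT, the full gradient is the sum over t of 2 * t * DT * a^(2t-1), while the summed one-step gradients give the sum over t of 2 * DT * a^(2t-1). The missing factor t is the contribution of the earlier states' dependence on `mu`, so the two agree only for a single step.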
Hi, I have been experimenting with brax and I am currently interested in taking the derivative of a reward function with respect to the friction parameter. The reward function is given by:
Interestingly, I found that computing the gradient for each one-step transition and summing those gradients over the full trajectory gives a different result from computing the full trajectory and then taking the gradient.
Does anybody know how JAX computes gradients in this case?
Thanks!