hackertyper closed this issue 8 months ago.
Hi @hackertyper, thanks for the bug report! If you can pinpoint which part of the pipeline returns a NaN using the debugging flags described at https://jax.readthedocs.io/en/latest/debugging/flags.html, that would help us see where the gradient computation hits a snag. So far, we haven't paid much attention to gradients in the generalized implementation.
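Here is a minimal sketch of that workflow. It uses a toy function rather than a brax pipeline. The gradient of `jnp.linalg.norm` at the zero vector is a classic source of NaN: its derivative is x / norm(x), which is 0/0 there. Enabling the `jax_debug_nans` flag makes JAX raise a `FloatingPointError` when an operation produces a NaN, so the same approach could help narrow down the offending step in a pipeline. The helper names `grad_of_norm` and `locate_nan` are hypothetical and used only for illustration.

```python
import jax
import jax.numpy as jnp


def grad_of_norm(x):
    # d||x||/dx = x / ||x||, which evaluates to 0/0 (NaN) at the zero vector.
    return jax.grad(jnp.linalg.norm)(x)


def locate_nan(x):
    """Re-run the gradient with jax_debug_nans enabled.

    Returns the FloatingPointError message (which typically names the
    primitive that produced the NaN), or None if no NaN was produced.
    """
    jax.config.update("jax_debug_nans", True)
    try:
        grad_of_norm(x)
        return None
    except FloatingPointError as err:
        return str(err)
    finally:
        # Restore the default so later computations are not affected.
        jax.config.update("jax_debug_nans", False)


x0 = jnp.zeros(3)
print(grad_of_norm(x0))  # NaN in every component
print(locate_nan(x0))
```

The flag is switched off again in `finally` because it slows execution and affects every later JAX computation in the process.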
This should be fixed in 1630403.
Hi,
first of all, thanks for your work on brax!
I have noticed an anomaly when computing gradients through the different pipelines.
When differentiating through the pipeline functions, only the positional and spring backends yield a finite number as the gradient. The generalized backend yields only NaN.
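A finite gradient is expected for a free-falling body. The sketch below does not use brax. It integrates a single free-falling point mass with semi-implicit Euler in plain JAX for 100 steps and differentiates the final height with respect to the initial height. Every step adds a term that does not depend on the initial height, so the derivative is exactly 1.0. This matches the value reported below for the spring and positional backends, although the original script may differentiate a different quantity. The step size `DT`, the `GRAVITY` constant, and the function names are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

DT = 0.01        # assumed integration step size (seconds)
GRAVITY = -9.81  # assumed gravitational acceleration along z (m/s^2)


def step(carry, _):
    z, vz = carry
    vz = vz + DT * GRAVITY  # semi-implicit Euler: update velocity first...
    z = z + DT * vz         # ...then position using the new velocity
    return (z, vz), None


def final_height(z0, vz0=0.0, n_steps=100):
    # Cast both carry entries to float32 so scan sees consistent carry types.
    carry = (jnp.asarray(z0, jnp.float32), jnp.asarray(vz0, jnp.float32))
    (z, _), _ = jax.lax.scan(step, carry, None, length=n_steps)
    return z


dz_dz0 = jax.grad(final_height)(1.0)
print(dz_dz0)  # → 1.0
```

`jax.lax.scan` is used instead of a Python loop so the 100 steps are traced once rather than unrolled.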
The following code reproduces the issue. It simulates a free-falling ball in an otherwise empty environment for 100 steps. The resulting gradients are 1.0 for `spring` and `positional`, but `nan` for `generalized`.

Output: