cdagher closed this issue 2 months ago.
Hi @cdagher, thanks for the bug report! This is a known issue; see a previous bug for more context: https://github.com/google-deepmind/mujoco/issues/1182
If you can use the Newton solver with `iterations=1` (https://github.com/google-deepmind/mujoco/blob/e77c3cb2581dd1f45e30d50b31cb0a88de26a626/mjx/mujoco/mjx/_src/solver.py#L360), you should be able to get some gradients. An alternative is to swap the `jax.lax.while_loop` for a `scan`, which we should perhaps add as a custom option.
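One way to read the `scan` suggestion: replace the early-exit loop with a fixed-length `jax.lax.scan` that freezes the state once the stop condition fails. The sketch below is a hypothetical helper (`bounded_while_loop` is not an MJX or JAX function) applied to a toy Newton iteration for the square root of `b`, not to MJX's actual solver, and it assumes a known upper bound `max_iter` on the iteration count. The result matches `while_loop` whenever the loop stops within `max_iter` steps, and `jax.grad` works because the trip count is static.

```python
import jax
import jax.numpy as jnp


def bounded_while_loop(cond_fun, body_fun, init_val, max_iter):
    """Hypothetical scan-based stand-in for jax.lax.while_loop.

    Runs exactly max_iter steps. Once cond_fun is False the carry is frozen,
    so the result matches while_loop whenever it stops within max_iter steps.
    The static trip count is what makes reverse-mode AD possible.
    """

    def step(carry, _):
        keep_going = cond_fun(carry)
        new_carry = body_fun(carry)
        # Keep the old carry (leaf by leaf) once the loop would have exited.
        carry = jax.tree_util.tree_map(
            lambda new, old: jnp.where(keep_going, new, old), new_carry, carry
        )
        return carry, None

    final, _ = jax.lax.scan(step, init_val, None, length=max_iter)
    return final


def newton_sqrt(b, loop="scan", tol=1e-6, max_iter=50):
    """Toy iterative solver: Newton's method for x**2 = b, stopping early."""
    b = jnp.asarray(b, jnp.float32)
    cond = lambda s: (s[0] < max_iter) & (jnp.abs(s[1] ** 2 - b) > tol)
    body = lambda s: (s[0] + 1, s[1] - (s[1] ** 2 - b) / (2 * s[1]))
    init = (jnp.int32(0), b)
    if loop == "while":
        _, x = jax.lax.while_loop(cond, body, init)
    else:
        _, x = bounded_while_loop(cond, body, init, max_iter)
    return x
```

Here `jax.grad(newton_sqrt)(2.0)` returns roughly 1/(2·√2) ≈ 0.354, while the same call with `loop="while"` raises the reverse-mode `ValueError`. The trade-off is that every call pays for all `max_iter` iterations, even when the solver converges early.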
Duplicate of #1182.
Hello,
I am migrating some code for a library I'm writing from Brax to MJX and am running into issues when taking gradients. I see that this has been raised in #1349, and I have found another example where the pipeline fails, documented in this Colab. The Colab is a simplified version of my actual code (for example, batching is removed), but it shows the same behavior I am observing.
This example consists of two systems:
The pendulum system has a NaN gradient, but as mentioned, this is likely the same as, or similar to, the problem reported in #1349. The new issue is a `ValueError` raised while calculating the gradient. If I use the simulated trajectories from the box stair system, JAX appears to treat my `fori_loop` as having non-static bounds and raises the following error:
ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop.
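For context, JAX decides at trace time how to lower `fori_loop`: when the trip count is static (for example, Python int bounds) it uses `scan`, which supports reverse mode; when the bounds are traced it falls back to `while_loop`, which does not. The message covers both loop types, so it is also raised for a plain `while_loop` anywhere on the differentiation path, including inside library code such as a constraint solver. A minimal sketch with a hypothetical toy `rollout` (not MJX code):

```python
import jax
import jax.numpy as jnp


def rollout(x0, n_steps):
    """Hypothetical toy dynamics standing in for a simulated trajectory."""
    body = lambda i, x: x + 0.1 * jnp.sin(x)
    return jax.lax.fori_loop(0, n_steps, body, x0)


# Static trip count: n_steps is a Python int at trace time, so fori_loop
# is lowered to scan and reverse-mode differentiation works.
def loss_static(x0):
    return rollout(x0, 10) ** 2


# Dynamic trip count: under jit, n_steps is a tracer, so fori_loop falls
# back to while_loop and jax.grad raises the ValueError quoted above.
@jax.jit
def loss_dynamic(x0, n_steps):
    return rollout(x0, n_steps) ** 2
```

`jax.grad(loss_static)(1.0)` succeeds, while `jax.grad(loss_dynamic)(1.0, 10)` raises the `ValueError` even though the step count has the same value; only whether it is known at trace time differs.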
I am using static bounds, though, and a loss function that always returns zero does not raise this error. My first thought was that it might be related to the maximum number of matrix-inverse or line-search iterations, but lowering both did not fix the issue. The stack trace shows that the error is raised in the backward pass through the MJX solver while it resolves constraint forces. Perhaps changing the `while_loop` to a `scan` would fix this, although that might come with a performance penalty (I'm thinking of always running the maximum number of iterations with no early exit).
Thank you in advance!
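Besides compute, always running the maximum number of iterations has a gradient pitfall: steps after convergence are still evaluated and differentiated, and masking them with `jnp.where` only hides bad values in the forward pass. If the masked-out update produces inf or NaN (here, a 0/0 once the residual is exactly zero), the gradient is NaN anyway, because the zero cotangent gets multiplied by NaN. The usual fix is the "double where" pattern: sanitize the inputs of the masked branch before evaluating it. This is a hypothetical toy solver (`approach` is illustrative, not MJX's solver) and is not meant as a diagnosis of the NaN in #1349.

```python
import jax
import jax.numpy as jnp


def approach(target, safe, x0=0.0, step=0.4, max_iter=10):
    """Hypothetical toy solver: move x toward target by at most `step` per
    iteration, as a fixed-length scan with a convergence mask."""

    def body(x, _):
        r = target - x
        done = r == 0.0
        if safe:
            # "Double where": give the masked-out update a harmless residual,
            # so it never evaluates 0 / 0 after convergence.
            r = jnp.where(done, 1.0, r)
        direction = r / jnp.abs(r)  # 0 / 0 = NaN when r == 0 and not safe
        x_new = x + jnp.minimum(step, jnp.abs(r)) * direction
        # The mask makes the forward value correct in both variants.
        return jnp.where(done, x, x_new), None

    x, _ = jax.lax.scan(body, jnp.float32(x0), None, length=max_iter)
    return x
```

Both variants return exactly 1.0 for `target=1.0`, but `jax.grad` of the unguarded version is NaN, while the guarded version gives a gradient close to 1.0.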
System information: