google-deepmind / mujoco

Multi-Joint dynamics with Contact. A general purpose physics simulator.
https://mujoco.org
Apache License 2.0

[MJX] Gradient of mjx_step is NaN or raises ValueError #1377

Closed cdagher closed 2 months ago

cdagher commented 9 months ago

Hello,

I am working on migrating some code for a library I'm writing from Brax to MJX, and am running into issues when taking gradients. I see that this has been raised in #1349, and I have found another example where the pipeline fails, documented in this Colab. I've made some adjustments relative to my actual code, such as removing batch functionality and making some simplifications, but the example shows the same behavior that I am observing.

This example consists of two systems:

- a pendulum
- a box on a set of stairs

The pendulum system has a NaN gradient, but as I said this is likely the same as, or similar to, the problem observed in #1349. The new issue I am having is a ValueError raised while calculating the gradient. It seems that if the loss uses the simulated trajectories of the box stair system, the compiler somehow concludes that my fori_loop has non-static bounds, raising the following error: ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop.
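
For concreteness, here is a minimal sketch (not the Colab code; the model path, initial state, and loss are placeholders) of the kind of rollout-and-gradient setup described above: a fixed-length fori_loop over mjx.step and a loss that depends on the simulated state.

```python
# Minimal sketch, not the Colab code: model path and loss are placeholders.
# The rollout uses lax.fori_loop with static Python-int bounds.
import jax
import jax.numpy as jnp
import mujoco
from mujoco import mjx

mj_model = mujoco.MjModel.from_xml_path("box_stairs.xml")  # hypothetical model file
mjx_model = mjx.put_model(mj_model)

NUM_STEPS = 100  # static Python int, so the fori_loop bounds are static


def rollout_loss(qpos0):
    data = mjx.make_data(mjx_model)
    data = data.replace(qpos=qpos0)

    def body(_, d):
        return mjx.step(mjx_model, d)

    data = jax.lax.fori_loop(0, NUM_STEPS, body, data)
    # A loss that actually uses the simulated state; a constant loss hides the error.
    return jnp.sum(data.qpos ** 2)


# Reverse-mode differentiation through the rollout. As described above, the
# ValueError appears in the backward pass inside the MJX solver's while_loop,
# not in the fori_loop here.
grad_fn = jax.jit(jax.grad(rollout_loss))
```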

I am using static bounds, though, and a loss function that always returns zero does not raise this error. My first thought was that it might have something to do with the maximum number of matrix-inverse or line-search iterations, but lowering both did not fix the issue. Looking at the stack trace, the error appears to be raised in the backward pass through the MJX solver while resolving constraint forces. Maybe modifying the while_loop to use scan instead would fix this, although it might come with a performance penalty (I'm thinking of always running the maximum number of iterations without early breakout).
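
To illustrate the kind of rewrite meant here, below is a generic JAX pattern (not the actual MJX solver code) for trading an early-exit while_loop for a fixed-length scan that always runs the maximum number of iterations but freezes the state once converged:

```python
# Generic pattern: an early-exit while_loop vs. a fixed-length scan that is
# reverse-mode differentiable. `update` and `converged` are user-supplied;
# `converged` must return a scalar jnp bool.
import jax
import jax.numpy as jnp

MAX_ITERS = 20


def solve_while(x0, update, converged):
    # Early exit: efficient, but not reverse-mode differentiable.
    def cond(carry):
        _, done = carry
        return ~done

    def body(carry):
        x, _ = carry
        x = update(x)
        return x, converged(x)

    x, _ = jax.lax.while_loop(cond, body, (x0, jnp.bool_(False)))
    return x


def solve_scan(x0, update, converged):
    # Always runs MAX_ITERS steps; once `done` is True the state is frozen,
    # so the result matches the while_loop but gradients can flow.
    def body(carry, _):
        x, done = carry
        x_new = jax.tree_util.tree_map(
            lambda old, new: jnp.where(done, old, new), x, update(x))
        return (x_new, done | converged(x_new)), None

    (x, _), _ = jax.lax.scan(body, (x0, jnp.bool_(False)), None, length=MAX_ITERS)
    return x


# Toy usage: fixed-point iteration x <- cos(x).
x = solve_scan(jnp.array(1.0), jnp.cos, lambda x: jnp.abs(x - jnp.cos(x)) < 1e-6)
```

Running all MAX_ITERS steps every time is exactly the performance trade-off mentioned above, but it removes the dynamic-length control flow that blocks reverse-mode differentiation.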

Thank you in advance!

System information:

btaba commented 9 months ago

Hi @cdagher, thanks for the bug report! This is a known issue; see a previous bug for more context: https://github.com/google-deepmind/mujoco/issues/1182

If you can use the Newton solver with iterations=1 (https://github.com/google-deepmind/mujoco/blob/e77c3cb2581dd1f45e30d50b31cb0a88de26a626/mjx/mujoco/mjx/_src/solver.py#L360), you should be able to get some gradients. An alternative is to swap the jax.lax.while_loop for a scan, which we should perhaps add as a custom option.
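
A sketch of that workaround, assuming a model loaded from MJCF (the path is a placeholder); the same settings can also be written into the XML via `<option solver="Newton" iterations="1"/>`:

```python
# Sketch of the suggested workaround: switch the model to the Newton solver
# with a single iteration before handing it to MJX. Path is a placeholder.
import mujoco
from mujoco import mjx

mj_model = mujoco.MjModel.from_xml_path("pendulum.xml")
mj_model.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON
mj_model.opt.iterations = 1
mj_model.opt.ls_iterations = 4  # optional: also cap line-search iterations

mjx_model = mjx.put_model(mj_model)  # per the suggestion above, gradients should now be obtainable
```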

kevinzakka commented 2 months ago

Duplicate of #1182.