google-deepmind / mujoco

Multi-Joint dynamics with Contact. A general purpose physics simulator.
https://mujoco.org
Apache License 2.0

[MJX] JAX.grad of mjx_step is producing NaNs #1349

Closed · Flogg0 closed this 7 months ago

Flogg0 commented 8 months ago

Hey all,

for a project of mine I am using the MuJoCo MJX pipeline and trying to train neural networks with JAX. Our idea is to alter the input to the physics step by placing a neural network in front of the step, which requires the ability to backpropagate through the step function. Since MJX is a spiritual successor to the Brax physics pipeline, we assumed the step should also be differentiable.

In order to verify this I set up a git repo: https://github.com/Flogg0/MJX-Backpropagation. There we initialize a state c0 and add a little bit of noise to its position and velocity to get the state v0 = c0 + noise. We compute the next states c1 and v1 with mjx.step and ctrl=0, then minimize the squared difference of the positions and velocities of c1 and v1 by changing v0. In other words, we backpropagate through mjx.step to adjust the perturbed state so that we end up with c1 == v1 and c0 == v0.
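For reference, the setup looks roughly like the following sketch (the XML string, noise scale, learning rate, and loop length are placeholders; the actual model and notebook are in the linked repo):

import jax
import jax.numpy as jp
import mujoco
from mujoco import mjx

mj_model = mujoco.MjModel.from_xml_string(model_xml)  # model_xml: placeholder XML string
mjx_model = mjx.put_model(mj_model)
c0 = mjx.put_data(mj_model, mujoco.MjData(mj_model))
c1 = mjx.step(mjx_model, c0)  # reference next state, ctrl stays 0

def loss(qpos, qvel):
    # Perturbed state v0 built from the reference state c0.
    v0 = c0.replace(qpos=qpos, qvel=qvel)
    v1 = mjx.step(mjx_model, v0)
    # Squared difference of positions and velocities between v1 and c1.
    return jp.sum((v1.qpos - c1.qpos) ** 2) + jp.sum((v1.qvel - c1.qvel) ** 2)

grad_fn = jax.jit(jax.grad(loss, argnums=(0, 1)))

# Gradient descent on the perturbed initial state, backpropagating through mjx.step.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
qpos = c0.qpos + 1e-2 * jax.random.normal(k1, c0.qpos.shape)
qvel = c0.qvel + 1e-2 * jax.random.normal(k2, c0.qvel.shape)
for _ in range(100):
    g_qpos, g_qvel = grad_fn(qpos, qvel)
    qpos = qpos - 1e-2 * g_qpos
    qvel = qvel - 1e-2 * g_qvel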


The gradient descent works and minimizes as expected until NaN values are encountered in the gradient. Using config.update("jax_debug_nans", True) I traced the source of the NaNs to mujoco/mjx/_src/smooth.py, line 130: subtree_com = jp.where(cond, d.xipos, jax.vmap(jp.divide)(pos, mass)).

I suppose the problem is, as described in https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where, that jp.where differentiates both of its branches, so the division on the unselected branch can still produce NaNs in the gradient. The repository above (https://github.com/Flogg0/MJX-Backpropagation/tree/master) is, in my opinion, a minimal reproduction of the issue.
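For context, the workaround suggested in that FAQ entry is the "double where": sanitize the unselected branch before it enters the division, so its (unused) gradient stays finite. A minimal sketch of the pattern (hypothetical helper, not the actual MJX fix):

import jax.numpy as jp

def where_divide(cond, fallback, num, den):
    # Both branches of jp.where are differentiated, so guard the division:
    # where `cond` is True the quotient is never used, but a zero denominator
    # there would still produce NaNs in the gradient.
    safe_den = jp.where(cond, 1.0, den)
    return jp.where(cond, fallback, num / safe_den)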

This does not seem like intended behavior; it would be great if you could fix it.

Thank you in advance!

System information:

btaba commented 7 months ago

Hi @Flogg0, thanks for flagging this with a detailed bug report and repository; we'll have a fix out shortly. Regarding your notebook: you'll want to batch your environments to get better performance if running on a GPU. Take a look at the tutorial, in particular:

rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, 4096)
# Build a batch of 4096 environments with randomized qpos.
batch = jax.vmap(lambda rng: mjx_data.replace(qpos=jax.random.uniform(rng, (1,))))(rng)

# jit + vmap the step: the model is shared (None), the data is batched (0).
jit_step = jax.jit(jax.vmap(mjx.step, in_axes=(None, 0)))
batch = jit_step(mjx_model, batch)
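Here in_axes=(None, 0) shares the single model across the batch while vectorizing over the data, and wrapping the vmapped step in jax.jit compiles it once so the whole batch advances in parallel on the accelerator.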