Closed Flogg0 closed 9 months ago
Hi @Flogg0, thanks for flagging this with a detailed bug report and repository. We'll have a fix out shortly. Regarding your notebook, you'll want to batch your environments to get better performance when running on a GPU. Take a look at the tutorial, in particular:
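import jax
from mujoco import mjx
# mjx_model and mjx_data are assumed to exist already (e.g. from mjx.put_model / mjx.put_data).
rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, 4096)  # one key per environment
# Build a batch of 4096 states, each with a random initial qpos.
batch = jax.vmap(lambda rng: mjx_data.replace(qpos=jax.random.uniform(rng, (1,))))(rng)
# Vectorize mjx.step over the batch (the model is shared), then JIT-compile it.
jit_step = jax.jit(jax.vmap(mjx.step, in_axes=(None, 0)))
batch = jit_step(mjx_model, batch)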
Hey all,
For a project of mine I am using the MuJoCo MJX pipeline and trying to train neural networks with JAX. We had the idea of altering the input to the physics step by placing a neural network in front of it. This requires the ability to backpropagate through the step function. Since MJX is a spiritual successor to the Brax physics pipeline, we expected it to be differentiable as well.
To verify this, I set up a git repo, https://github.com/Flogg0/MJX-Backpropagation, where we initialize a state `c0` and add a little noise to its velocity and position to get the state `v0 = c0 + noise`. We compute the next states `c1` and `v1` with `mjx.step` and `ctrl=0`. We then minimize the squared difference between the positions and velocities of `c1` and `v1` by changing `v0`. That is, we backpropagate through `mjx.step` to update the altered state, so that we should end up with `c1 == v1` and `c0 == v0`.
The gradient descent works and minimizes the loss as expected until NaN values appear in the gradient. Using `config.update("jax_debug_nans", True)`, I was able to see that the source of the NaNs is in `mujoco/mjx/_src/smooth.py:130`:
--> 130 subtree_com = jp.where(cond, d.xipos, jax.vmap(jp.divide)(pos, mass))
I suppose the problem is the `jp.where`, as described in https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where: both branches of the `where` are used when computing the gradient. I created what is, in my opinion, a minimal repository, https://github.com/Flogg0/MJX-Backpropagation/tree/master, where one can replicate my findings. This does not seem like intended behavior; maybe you could fix it.
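To illustrate what I think is happening, here is a minimal sketch in plain JAX, independent of MJX. The function names `com_naive` and `com_safe`, the `fallback` argument, and the toy values are hypothetical; they only mimic the shape of the reported line and are not the actual MJX code or its fix. The second variant uses the "double where" pattern from the JAX FAQ, which replaces the zero denominator before dividing.

```python
import jax
import jax.numpy as jnp


def com_naive(pos, mass, fallback):
    # Same shape as the reported line: the division is evaluated for every
    # element, including those where mass == 0, and the backward pass through
    # that unselected branch computes 0 / 0 = NaN.
    return jnp.where(mass == 0.0, fallback, pos / mass)


def com_safe(pos, mass, fallback):
    # "Double where": make the denominator harmless first, so the branch that
    # is not selected never produces inf or NaN in the backward pass.
    is_zero = mass == 0.0
    safe_mass = jnp.where(is_zero, 1.0, mass)
    return jnp.where(is_zero, fallback, pos / safe_mass)


# Toy inputs: the second entry has zero mass.
pos = jnp.array([2.0, 3.0])
mass = jnp.array([4.0, 0.0])
fallback = jnp.array([0.0, 0.0])

# Gradient of the summed output with respect to pos.
g_naive = jax.grad(lambda p: jnp.sum(com_naive(p, mass, fallback)))(pos)
g_safe = jax.grad(lambda p: jnp.sum(com_safe(p, mass, fallback)))(pos)

print("naive:", g_naive)
print("safe: ", g_safe)
```

The forward values of both variants are identical; only the gradient differs, with the naive version returning NaN for the zero-mass entry.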
Thank you in advance!
System information: