We’re experimenting with the differentiable simulation features in MJX (working with @Daniellayeghi from the University of Edinburgh). While testing, we encountered an issue where computing gradients results in NaN values.
My setup
OS: Ubuntu 22.04
Python: 3.10.12
MuJoCo: 3.2.5
JAX: 0.4.33
Cuda compilation tools, release 12.6, V12.6.77
Build cuda_12.6.r12.6/compiler.34841621_0
What's happening? What did you expect?
Description
We wrote a simple example to test the gradient behavior. We are optimizing to find the optimal impulse for a capsule to collide with a sphere so that the sphere moves to the origin of the scene (mocap target). We provide an initial guess for the impulse to a simple gradient-based optimizer. We define a loss function based on the position of the second object (sphere). We apply a single control force along the x-axis to the first object (capsule) on the first time step. When we compute the gradient of the trajectory loss (cumulative running cost) with respect to the impulse using jax.jacrev, we consistently observe NaN values. You should be able to visualize the trajectory resulting from the initial guess when running the example below.
We have tested with different geometries, including Capsule, Box, and Sphere. We specifically use Capsule-Sphere in our example because the majority of our own project uses this type of collision geometry.
Comments
dt = 0.01; Initial force guess u = 2:
This guess induces a trajectory that creates a lot of contact, and optimizing over it results in NaN values during the optimization.
dt = 0.01; Initial force guess: u = 15:
This guess induces a trajectory with what looks like a single and short contact. The optimizer is able to optimize it without NaN values.
dt = 0.001:
Reducing dt = 0.001 (and proportionally increasing the initial guess) results in immediate NaNs.
solimp, solref:
We have tested different contact values (solimp, solref, friction) as well, which still gives NaN values. We noted that these influence and modify the appearance of the NaN values but still result in NaNs.
collision type:
For sphere-sphere collision cases, the NaN values are less common but still appear, and it’s unclear at what point they arise.
JIT:
Running without JIT (not JIT-ing any of our functions) does not make the NaN values go away.
Is there a reason for this behavior? Are we making mistakes in our setup?
Steps for reproduction
Run the following python code.
Minimal model for reproduction
See below.
Code required for reproduction
Minimal example
python code
```python
import jax
import jax.numpy as jnp
import numpy as np
import mujoco
from mujoco import mjx
import os
from jax import config
# Jax configs
# Set JAX's default matmul precision option to 'high'.
config.update('jax_default_matmul_precision', 'high')
# Optional switches, left disabled: NaN debugging and 64-bit mode.
# config.update("jax_debug_nans", True)
# config.update("jax_enable_x64", True)
def running_cost(dx, qpos_index=7, force_index=0, force_weight=0.00001):
    """Return the running cost of one simulation state as a length-1 array.

    The cost is ``dx.qpos[qpos_index]**2`` plus ``force_weight`` times
    ``dx.qfrc_applied[force_index]**2``. The defaults reproduce the
    original hard-coded cost (state entry 7, force entry 0, weight 1e-5).

    Args:
        dx: Object exposing indexable ``qpos`` and ``qfrc_applied`` arrays.
        qpos_index: Index into ``dx.qpos`` of the coordinate to penalise.
        force_index: Index into ``dx.qfrc_applied`` of the penalised force.
        force_weight: Weight of the squared-force term.

    Returns:
        A ``jnp`` array of shape ``(1,)`` holding the cost.
    """
    position_term = dx.qpos[qpos_index] ** 2
    effort_term = force_weight * dx.qfrc_applied[force_index] ** 2
    return jnp.array([position_term + effort_term])
@jax.vmap
def simulate_trajectory_mjx(qpos_init, u):
    """Roll the MJX model forward for ``Nlength`` steps and record the run.

    The function is wrapped in ``jax.vmap``, so callers pass batched
    ``qpos_init`` and ``u`` and each batch entry is simulated separately.
    ``u`` is written to ``qfrc_applied`` before the first step and is
    zeroed after every step, so it only acts during the first step.

    Returns:
        A tuple ``(states, cost, t)``: per step, the concatenated
        ``qpos`` and ``qvel``, the running cost, and the simulation time.
    """
    def _advance(state, _unused):
        stepped = mjx.step(mx, state)  # Dynamics function
        # The cost is evaluated while the applied force is still set;
        # the force is cleared only afterwards.
        step_cost = running_cost(stepped)
        stamp = jnp.expand_dims(stepped.time, axis=0)
        record = jnp.concatenate(
            [stepped.qpos, stepped.qvel, step_cost, stamp]
        )
        cleared = stepped.qfrc_applied.at[:].set(
            jnp.zeros_like(stepped.qfrc_applied)
        )
        return stepped.replace(qfrc_applied=cleared), record

    start = mjx.make_data(mx)
    start = start.replace(qpos=start.qpos.at[:].set(qpos_init))
    start = start.replace(qfrc_applied=start.qfrc_applied.at[:].set(u))
    _, history = jax.lax.scan(_advance, start, None, length=Nlength)
    # The last two columns of each record are the cost and the time.
    return history[..., :-2], history[..., -2], history[..., -1]
def compute_trajectory_costs(qpos_init, u):
    """Simulate the batch and return the costs first.

    Thin wrapper around ``simulate_trajectory_mjx`` that reorders its
    outputs to ``(cost, states, t)`` so the quantity to differentiate
    comes first.
    """
    rollout = simulate_trajectory_mjx(qpos_init, u)
    return rollout[1], rollout[0], rollout[2]
def visu_u(u0):
    """Simulate a single force ``u0`` and replay the trajectory in a viewer.

    The trajectory is simulated with ``simulate_trajectory_mjx`` from the
    configuration stored in the module-level ``idata``, then replayed
    frame by frame in MuJoCo's passive viewer.

    Args:
        u0: Scalar force; ``set_u`` writes it to entry 0 of the applied
            force vector.
    """
    def visualise(qpos, qvel):
        import time
        from mujoco import viewer as mj_viewer

        # Copy the trajectory to host memory once instead of indexing the
        # JAX arrays frame by frame.
        qpos_frames = np.asarray(qpos)
        qvel_frames = np.asarray(qvel)

        data = mujoco.MjData(model)
        data.qpos[:] = idata.qpos
        # Bind the handle to its own name so the imported module is not
        # shadowed.
        with mj_viewer.launch_passive(model, data) as handle:
            for frame_qpos, frame_qvel in zip(qpos_frames, qvel_frames):
                # Stop replaying once the user has closed the window.
                if not handle.is_running():
                    break
                step_start = time.time()
                data.qpos[:] = frame_qpos
                data.qvel[:] = frame_qvel
                mujoco.mj_forward(model, data)
                handle.sync()
                time_until_next_step = model.opt.timestep - (time.time() - step_start)
                if time_until_next_step > 0:
                    # Fixed pause, as in the original, rather than the
                    # remaining step time.
                    time.sleep(0.075)

    qpos = jnp.expand_dims(jnp.array(idata.qpos), axis=0)
    u = set_u(jnp.array([u0]))
    res, _, _ = simulate_trajectory_mjx(qpos, u)
    # Each recorded state is qpos followed by qvel; split at model.nq.
    qpos_mjx, qvel_mjx = res[0, :, :model.nq], res[0, :, model.nq:]
    visualise(qpos_mjx, qvel_mjx)
@jax.jit
def compute_loss_grad(qpos_init, u):
    """Return the reverse-mode Jacobian of ``loss`` with respect to ``u``."""
    impulse = jnp.array(u)
    # argnums=1 differentiates with respect to the second argument of loss.
    return jax.jacrev(loss, argnums=1)(qpos_init, impulse)
@jax.jit
def loss(qpos_init, u):
    """Return the batch mean of the summed running costs.

    ``u`` is expanded to full applied-force vectors with ``set_u``. The
    running costs are summed along axis 1 for each batch entry, then
    averaged over the batch.
    """
    running = compute_trajectory_costs(qpos_init, set_u(u))[0]
    per_trajectory = jnp.sum(running, axis=1)
    return jnp.mean(per_trajectory)
def set_u(u0):
    """Build a batch of applied-force vectors with ``u0`` in entry 0.

    Creates one zero vector shaped like ``idata.qfrc_applied`` per element
    of ``u0`` and writes ``u0`` into column 0 of the resulting batch.
    """
    template = jnp.zeros_like(idata.qfrc_applied)
    batch_shape = (u0.shape[0],) + template.shape
    forces = jnp.broadcast_to(template, batch_shape)
    return forces.at[:, 0].set(u0)
@jax.jit
def compute_traj_grad_wrt_u(qpos_init, u):
    """Return the Jacobian of the running costs with respect to ``u``.

    Differentiates the un-reduced cost array rather than the scalar loss,
    which helps when debugging where NaN values occur.
    """
    def _costs_of(impulse):
        return compute_trajectory_costs(qpos_init, set_u(impulse))[0]

    return jax.jacrev(_costs_of)(jnp.array(u))
# Gradient descent
def gradient_descent(qpos, x0, learning_rate=0.1, tol=1e-6, max_iter=100):
    """Optimise the initial force with plain gradient descent.

    Each iteration applies the full gradient returned by
    ``compute_loss_grad``, so every batch entry is updated with its own
    component and batches larger than one are handled.

    Args:
        qpos: Batched initial configurations passed through to the loss.
        x0: Initial guess for the force, one entry per batch element.
        learning_rate: Step size of the update.
        tol: Convergence threshold on the change of every entry of ``x``.
        max_iter: Maximum number of iterations.

    Returns:
        The last finite iterate: the converged value, the value after
        ``max_iter`` updates, or the iterate preceding a NaN gradient.
    """
    x = x0
    for i in range(max_iter):
        grad = compute_loss_grad(qpos, x)
        x_new = x - learning_rate * grad  # Gradient descent update
        print(f"Iteration {i}: x = {x_new}, f(x) = {loss(qpos, x_new)}")
        # A NaN gradient makes x_new unusable; keep the previous iterate.
        if bool(jnp.any(jnp.isnan(grad))):
            break
        # Reduce explicitly so the check is well defined for any batch size.
        converged = bool(jnp.all(jnp.abs(x_new - x) < tol))
        x = x_new
        if converged:
            break
    return x
# The XML model as a string
# NOTE(review): the string is empty in this copy of the example. The model
# description passed to mujoco.MjModel.from_xml_string has to be pasted
# here, presumably the capsule/sphere scene described in the issue (confirm).
model_xml = """
"""
if __name__ == "__main__":
    # Load mj and mjx model
    model = mujoco.MjModel.from_xml_string(model_xml)
    mx = mjx.put_model(model)
    idata = mujoco.MjData(model)

    # Initial conditions, written to entries 0, 2 and 7 of qpos
    qx0, qz0, qx1 = -0.375, 0.1, -0.2
    idata.qpos[0] = qx0
    idata.qpos[2] = qz0
    idata.qpos[7] = qx1

    Nlength = 100  # horizon length

    # Initial guess
    u0 = 2.
    batch = 1
    u0_jnp = jnp.array([u0])

    # Replicate the initial configuration once per batch entry
    qpos = jnp.repeat(
        jnp.expand_dims(jnp.array(idata.qpos), axis=0), batch, axis=0
    )

    # Visualise guess
    visu_u(u0)

    # Run gradient descent
    optimal_x = gradient_descent(qpos, u0_jnp, learning_rate=0.1)

    # Check the gradient along the trajectory, debugging
    # print(compute_traj_grad_wrt_u(qpos, jnp.array([2.8]))) # Working
    # print(compute_traj_grad_wrt_u(qpos, jnp.array([2.9]))) # NaNs
```
Intro
Hello!
We’re experimenting with the differentiable simulation features in MJX (working with @Daniellayeghi from the University of Edinburgh). While testing, we encountered an issue where computing gradients results in NaN values.
My setup
What's happening? What did you expect?
Description
We wrote a simple example to test the gradient behavior. We are optimizing to find the optimal impulse for a capsule to collide with a sphere so that the sphere moves to the origin of the scene (mocap target). We provide an initial guess for the impulse to a simple gradient-based optimizer. We define a loss function based on the position of the second object (sphere). We apply a single control force along the x-axis to the first object (capsule) on the first time step. When we compute the gradient of the trajectory loss (cumulative running cost) with respect to the impulse using `jax.jacrev`, we consistently observe NaN values. You should be able to visualize the trajectory resulting from the initial guess when running the example below.
We have tested with different geometries, including Capsule, Box, and Sphere. We specifically use Capsule-Sphere in our example because the majority of our own project uses this type of collision geometry.
Comments
dt = 0.01; Initial force guess u = 2:
This guess induces a trajectory that creates a lot of contact, and optimizing over it results in NaN values during the optimization.
dt = 0.01; Initial force guess: u = 15:
This guess induces a trajectory with what looks like a single and short contact. The optimizer is able to optimize it without NaN values.
dt = 0.001:
Reducing `dt = 0.001` (and proportionally increasing the initial guess) results in immediate NaNs.
solimp, solref:
We have tested different contact values (`solimp`, `solref`, `friction`) as well, which still gives NaN values. We noted that these influence and modify the appearance of the NaN values but still result in NaNs.
collision type:
For sphere-sphere collision cases, the NaN values are less common but still appear, and it’s unclear at what point they arise.
JIT:
Running without JIT (not JIT-ing any of our functions) does not make the NaN values go away.
Is there a reason for this behavior? Are we making mistakes in our setup?
Steps for reproduction
Run the following python code.
Minimal model for reproduction
See below.
Code required for reproduction
Minimal example
python code
```python import jax import jax.numpy as jnp import numpy as np import mujoco from mujoco import mjx import os from jax import config # Jax configs config.update('jax_default_matmul_precision', 'high') # config.update("jax_debug_nans", True) # config.update("jax_enable_x64", True) def running_cost(dx): return jnp.array([dx.qpos[7]**2 + 0.00001*dx.qfrc_applied[0]**2]) @jax.vmap def simulate_trajectory_mjx(qpos_init, u): """ Simulate the impulse of the force. Return states and costs.""" def step_scan_mjx(carry, _): dx = carry dx = mjx.step(mx, dx) # Dynamics function t = jnp.expand_dims(dx.time, axis=0) cost = running_cost(dx) dx = dx.replace(qfrc_applied=dx.qfrc_applied.at[:].set(jnp.zeros_like(dx.qfrc_applied))) return (dx), jnp.concatenate([dx.qpos, dx.qvel, cost, t]) dx = mjx.make_data(mx) dx = dx.replace(qpos=dx.qpos.at[:].set(qpos_init)) dx = dx.replace(qfrc_applied=dx.qfrc_applied.at[:].set(u)) (dx), res = jax.lax.scan(step_scan_mjx, (dx), None, length=Nlength) res, cost, t = res[...,:-2], res[...,-2], res[...,-1] return res, cost, t def compute_trajectory_costs(qpos_init, u): """ Wrapper function to compute the gradient wrt costs.""" res, cost, t = simulate_trajectory_mjx(qpos_init, u) return cost, res, t def visu_u(u0): """ Visualisation for a single force.""" def visualise(qpos, qvel): import time from mujoco import viewer data = mujoco.MjData(model) data.qpos = idata.qpos with viewer.launch_passive(model, data) as viewer: for i in range(qpos.shape[0]): step_start = time.time() data.qpos[:] = qpos[i] data.qvel[:] = qvel[i] mujoco.mj_forward(model, data) viewer.sync() time_until_next_step = model.opt.timestep - (time.time() - step_start) if time_until_next_step > 0: # time.sleep(time_until_next_step) time.sleep(0.075) qpos = jnp.expand_dims(jnp.array(idata.qpos), axis=0) qpos = jnp.repeat(qpos,1, axis = 0) u = set_u(jnp.array([u0])) res, _, _ = simulate_trajectory_mjx(qpos, u) qpos_mjx, qvel_mjx = res[0,:,:model.nq], res[0,:,model.nq:] visualise(qpos_mjx, 
qvel_mjx) @jax.jit def compute_loss_grad(qpos_init, u): """ Compute gradient of loss wrt u""" jac_fun = jax.jacrev(lambda x: loss(qpos_init,x)) ad_grad = jac_fun(jnp.array(u)) return ad_grad @jax.jit def loss(qpos_init, u): """ Sum of running costs""" u = set_u(u) costs = compute_trajectory_costs(qpos_init,u)[0] costs = jnp.sum(costs, axis=1) costs = jnp.mean(costs) return costs def set_u(u0): """ Utility function to initialize data""" u = jnp.zeros_like(idata.qfrc_applied) u = jnp.expand_dims(u, axis=0) u = jnp.repeat(u, u0.shape[0], axis=0) u = u.at[:,0].set(u0) return u @jax.jit def compute_traj_grad_wrt_u(qpos_init,u): """ Gradient vector to debug the NaN occurence.""" jac_fun = jax.jacrev(lambda x: compute_trajectory_costs(qpos_init,set_u(x))[0]) ad_grad = jac_fun(jnp.array(u)) return ad_grad # Gradient descent def gradient_descent(qpos, x0, learning_rate=0.1, tol=1e-6, max_iter=100): """ Optimise initial force.""" x = x0 for i in range(max_iter): grad = compute_loss_grad(qpos, x)[0] x_new = x - learning_rate * grad # Gradient descent update print(f"Iteration {i}: x = {x_new}, f(x) = {loss(qpos, x_new)}") # Check for convergence if abs(x_new - x) < tol or jnp.isnan(grad): break x = x_new return x # The XML model as a string model_xml = """Thank you !
Confirmations