google-deepmind / mujoco

Multi-Joint dynamics with Contact. A general purpose physics simulator.
Apache License 2.0

NaN Error in Gradient Computation of MJX Simulation with Contacts #2237

Open thomascbrs opened 1 day ago




We’re experimenting with the differentiable simulation features in MJX (working with @Daniellayeghi from the University of Edinburgh). While testing, we encountered an issue where computing gradients results in NaN values.

My setup

What's happening? What did you expect?


We wrote a simple example to test the gradient behavior. The goal is to find the impulse that makes a capsule collide with a sphere so that the sphere ends up at the origin of the scene (the mocap target). A single control force along the x-axis is applied to the first object (the capsule) on the first time step only. The loss is the cumulative running cost of the trajectory, based on the position of the second object (the sphere). Starting from an initial guess for the impulse, a simple gradient-based optimizer minimizes this loss. When we compute the gradient of the trajectory loss with respect to the impulse using jax.jacrev, we consistently observe NaN values. Running the example below also visualizes the trajectory produced by the initial guess.
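For reference, the loss minimized by the script below can be written out as follows. This is our transcription of the code, not an MJX definition: $q^{(k)}$ is `qpos` after the $k$-th `mjx.step`, $q^{(k)}_{b,7}$ is the sphere's x-coordinate in batch element $b$ (index 7, as used in `running_cost`), $N$ = `Nlength` = 100, $B$ = `batch` = 1, and $u_b$ is the force applied to the capsule's x degree of freedom. Because `qfrc_applied` is reset to zero after the first step, the force penalty contributes only once:

$$
\mathcal{L}(u) = \frac{1}{B}\sum_{b=1}^{B}\Big[\,10^{-5}\,u_b^{2} + \sum_{k=1}^{N}\big(q^{(k)}_{b,7}(u_b)\big)^{2}\Big],
\qquad
\frac{\partial \mathcal{L}}{\partial u_b} = \frac{1}{B}\Big[\,2\cdot 10^{-5}\,u_b + 2\sum_{k=1}^{N} q^{(k)}_{b,7}\,\frac{\partial q^{(k)}_{b,7}}{\partial u_b}\Big]
$$

Every step's sensitivity $\partial q^{(k)}_{b,7}/\partial u_b$ enters the sum, so a NaN produced in the backward pass of any single `mjx.step` (for instance the step where contact occurs) is enough to make the whole gradient NaN.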

We have tested different geometries, including capsule, box, and sphere. The example uses a capsule-sphere pair because most of the collisions in our own project involve this combination.


Is there a reason for this behavior? Are we making mistakes in our setup?
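One possible explanation, which we have not verified against the MJX source: in JAX, a `jnp.where` that guards a singular expression does not protect its gradient. Reverse mode still differentiates the branch that was not selected, and a `0 * inf` produced there becomes NaN. The norm of a zero vector is the textbook case; whether a distance or normal computed in the capsule-sphere collision hits a similar singularity is only an assumption. The sketch below (the function names are illustrative, not MJX APIs) reproduces the effect and the common "double where" workaround, which replaces the singular input before the unsafe operation.

```python
import jax
import jax.numpy as jnp


def norm_naive(v):
    # Forward pass is fine: sqrt(0) = 0 and the where selects 0.0.
    # Backward pass is not: the cotangent reaching sqrt is 0, and
    # 0 * d(sqrt)/ds = 0 * (0.5 / 0) = 0 * inf = NaN.
    n = jnp.sqrt(jnp.sum(v**2))
    return jnp.where(n > 0, n, 0.0)


def norm_safe(v):
    # "Double where": first swap the singular input for a harmless value,
    # then select the intended output. Neither branch produces inf in the
    # backward pass, so the gradient at v = 0 is 0 instead of NaN.
    s = jnp.sum(v**2)
    is_zero = s == 0.0
    s_safe = jnp.where(is_zero, 1.0, s)
    return jnp.where(is_zero, 0.0, jnp.sqrt(s_safe))


zero = jnp.zeros(3)
print(jax.grad(norm_naive)(zero))  # all NaN
print(jax.grad(norm_safe)(zero))   # all zeros
```

Away from the singular point both versions return the same value and gradient; only the behavior at the guarded input differs.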

Steps for reproduction

Run the following Python code.

Minimal model for reproduction

See below.

Code required for reproduction

Minimal example

```python
import os

import jax
import jax.numpy as jnp
import numpy as np
import mujoco
from mujoco import mjx
from jax import config

# Jax configs
config.update('jax_default_matmul_precision', 'high')
# config.update("jax_debug_nans", True)
# config.update("jax_enable_x64", True)


def running_cost(dx):
    # Squared x-position of the sphere (qpos[7]) plus a small penalty on the applied force.
    return jnp.array([dx.qpos[7]**2 + 0.00001*dx.qfrc_applied[0]**2])


@jax.vmap
def simulate_trajectory_mjx(qpos_init, u):
    """ Simulate the impulse of the force. Return states and costs."""
    def step_scan_mjx(carry, _):
        dx = carry
        dx = mjx.step(mx, dx)  # Dynamics function
        t = jnp.expand_dims(dx.time, axis=0)
        cost = running_cost(dx)
        # The force acts on the first step only: zero it for the following steps.
        dx = dx.replace(qfrc_applied=dx.qfrc_applied.at[:].set(jnp.zeros_like(dx.qfrc_applied)))
        return dx, jnp.concatenate([dx.qpos, dx.qvel, cost, t])

    dx = mjx.make_data(mx)
    dx = dx.replace(qpos=dx.qpos.at[:].set(qpos_init))
    dx = dx.replace(qfrc_applied=dx.qfrc_applied.at[:].set(u))
    dx, res = jax.lax.scan(step_scan_mjx, dx, None, length=Nlength)
    res, cost, t = res[..., :-2], res[..., -2], res[..., -1]
    return res, cost, t


def compute_trajectory_costs(qpos_init, u):
    """ Wrapper function to compute the gradient wrt costs."""
    res, cost, t = simulate_trajectory_mjx(qpos_init, u)
    return cost, res, t


def visu_u(u0):
    """ Visualisation for a single force."""
    def visualise(qpos, qvel):
        import time
        from mujoco import viewer
        data = mujoco.MjData(model)
        data.qpos[:] = idata.qpos
        with viewer.launch_passive(model, data) as v:
            for i in range(qpos.shape[0]):
                step_start = time.time()
                data.qpos[:] = qpos[i]
                data.qvel[:] = qvel[i]
                mujoco.mj_forward(model, data)
                v.sync()
                time_until_next_step = model.opt.timestep - (time.time() - step_start)
                if time_until_next_step > 0:
                    # time.sleep(time_until_next_step)
                    time.sleep(0.075)

    qpos = jnp.expand_dims(jnp.array(idata.qpos), axis=0)
    qpos = jnp.repeat(qpos, 1, axis=0)
    u = set_u(jnp.array([u0]))
    res, _, _ = simulate_trajectory_mjx(qpos, u)
    qpos_mjx, qvel_mjx = res[0, :, :model.nq], res[0, :, model.nq:]
    visualise(qpos_mjx, qvel_mjx)


@jax.jit
def compute_loss_grad(qpos_init, u):
    """ Compute gradient of loss wrt u"""
    jac_fun = jax.jacrev(lambda x: loss(qpos_init, x))
    ad_grad = jac_fun(jnp.array(u))
    return ad_grad


@jax.jit
def loss(qpos_init, u):
    """ Sum of running costs"""
    u = set_u(u)
    costs = compute_trajectory_costs(qpos_init, u)[0]
    costs = jnp.sum(costs, axis=1)  # sum over time steps
    costs = jnp.mean(costs)  # mean over the batch
    return costs


def set_u(u0):
    """ Utility function to initialize data"""
    u = jnp.zeros_like(idata.qfrc_applied)
    u = jnp.expand_dims(u, axis=0)
    u = jnp.repeat(u, u0.shape[0], axis=0)
    u = u.at[:, 0].set(u0)  # force along x on the capsule
    return u


@jax.jit
def compute_traj_grad_wrt_u(qpos_init, u):
    """ Gradient vector to debug the NaN occurrence."""
    jac_fun = jax.jacrev(lambda x: compute_trajectory_costs(qpos_init, set_u(x))[0])
    ad_grad = jac_fun(jnp.array(u))
    return ad_grad


# Gradient descent
def gradient_descent(qpos, x0, learning_rate=0.1, tol=1e-6, max_iter=100):
    """ Optimise initial force."""
    x = x0
    for i in range(max_iter):
        grad = compute_loss_grad(qpos, x)[0]
        x_new = x - learning_rate * grad  # Gradient descent update
        print(f"Iteration {i}: x = {x_new}, f(x) = {loss(qpos, x_new)}")
        # Check for convergence
        if abs(x_new - x) < tol or jnp.isnan(grad):
            break
        x = x_new
    return x


# The XML model as a string
model_xml = """ """

if __name__ == "__main__":
    # Load mj and mjx model
    model = mujoco.MjModel.from_xml_string(model_xml)
    mx = mjx.put_model(model)
    idata = mujoco.MjData(model)

    qx0, qz0, qx1 = -0.375, 0.1, -0.2  # Initial conditions
    idata.qpos[0], idata.qpos[2], idata.qpos[7] = qx0, qz0, qx1
    Nlength = 100  # horizon length
    u0, batch = 2., 1  # Initial guess
    u0_jnp = jnp.array([u0])

    qpos = jnp.expand_dims(jnp.array(idata.qpos), axis=0)
    qpos = jnp.repeat(qpos, batch, axis=0)

    # Visualise guess
    visu_u(u0)

    # Run gradient descent
    optimal_x = gradient_descent(qpos, u0_jnp, learning_rate=0.1)

    # Check the gradient along the trajectory, debugging
    # print(compute_traj_grad_wrt_u(qpos, jnp.array([2.8])))  # Working
    # print(compute_traj_grad_wrt_u(qpos, jnp.array([2.9])))  # NaNs
```
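The commented-out lines at the end of the script indicate a finite gradient at u = 2.8 and NaNs at u = 2.9. A bisection between such a pair can narrow down the input where the gradient first breaks, which makes it easier to inspect the corresponding contact state. The helpers `has_nan` and `bisect_nan_onset` below are hypothetical, not part of MJX or JAX, and the sketch assumes a finite-to-NaN transition inside the bracket; if there are several, it finds only one of them.

```python
import jax.numpy as jnp


def has_nan(x):
    """True if any entry of the (array-like) x is NaN."""
    return bool(jnp.any(jnp.isnan(jnp.asarray(x))))


def bisect_nan_onset(grad_fn, lo, hi, iters=30):
    """Narrow down an input where grad_fn switches from finite to NaN.

    Requires grad_fn(lo) to be NaN-free and grad_fn(hi) to contain a NaN.
    Returns (lo, hi) bracketing one such transition; there may be others.
    """
    if has_nan(grad_fn(lo)) or not has_nan(grad_fn(hi)):
        raise ValueError("need a finite gradient at lo and a NaN gradient at hi")
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if has_nan(grad_fn(mid)):
            hi = mid  # NaN at mid: the transition is at or below mid
        else:
            lo = mid  # finite at mid: the transition is above mid
    return lo, hi
```

With the script above, a call could look like `bisect_nan_onset(lambda s: compute_traj_grad_wrt_u(qpos, jnp.array([s])), 2.8, 2.9)`; since the input shape stays fixed, the jitted function should not need to recompile between calls.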

Thank you!
