google-deepmind / mujoco

Multi-Joint dynamics with Contact. A general purpose physics simulator.
https://mujoco.org
Apache License 2.0

The CG Solver in MJX doesn't support reverse-mode differentiation #1182

Open LyuJ1998 opened 9 months ago

LyuJ1998 commented 9 months ago

I'm trying to differentiate the MJX step function via the autograd function jax.grad() in JAX, like:

def step(vel, pos):
  mjx_data = mjx.make_data(mjx_model)
  mjx_data = mjx_data.replace(qvel = vel, qpos = pos)
  pos = mjx.step(mjx_model, mjx_data).qpos
  return pos

def loss(vel, pos):
  pos = step(vel, pos)
  return jnp.sum((pos - goal_pos)**2)

grad_loss = jax.jit(jax.grad(loss))
grad = grad_loss(vel, pos)

When there is only one rigid body in the scene, everything works. But when a collision has to be resolved, for example with a ball and a plane in the scene:

XML = """
<mujoco>
  <asset>
    <texture name="grid" type="2d" builtin="checker" rgb1=".1 .2 .3"
    rgb2=".2 .3 .4" width="300" height="300" mark="edge" markrgb=".2 .3 .4"/>
    <material name="grid" texture="grid" texrepeat="2 2" texuniform="true"
    reflectance=".2"/>
  </asset>
  <worldbody>
    <geom name="ground" type="plane" pos="0 0 -.5" size="2 2 .1" material="grid" solimp=".99 .99 .01" solref=".001 1"/>
    <body>
      <freejoint/>
      <geom size=".15" mass="1" type="sphere"/>
    </body>
  </worldbody>
</mujoco>
"""

the following error occurs:

File "/path-to-mujoco/mjx/_src/solver.py", line 347, in cg_solve
    ctx = jax.lax.while_loop(cond, body, _CGContext.create(m, d))
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop.

It seems that jax.lax.while_loop(), which the CG solver uses, does not support reverse-mode differentiation when its condition function is dynamic. How can I solve this?

LyuJ1998 commented 9 months ago

I'm also trying to replace the ctx = jax.lax.while_loop(cond, body, _CGContext.create(m, d)) in mjx/_src/solver.py Line 347 with a simpler while function:

def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val
ctx = while_loop(cond, body, _CGContext.create(m, d))

It works when I don't use jax.jit() to compile the gradient function, but with jax.jit() I get another error:

File "/path-to-mujoco/mjx/_src/solver.py", line 349, in while_loop
    while cond_fun(val):
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function loss at /home/lvjun/Mujoco3/demo_mjx.py:55 for jit. This concrete value was not available in Python because it depends on the values of the arguments vel and pos.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

Is this because improvement and gradient in cond() are not static?

  def cond(ctx: _CGContext) -> jax.Array:
    improvement = _rescale(m, ctx.prev_cost - ctx.cost)
    gradient = _rescale(m, math.norm(ctx.grad))

    done = ctx.solver_niter >= m.opt.iterations
    done |= improvement < m.opt.tolerance
    done |= gradient < m.opt.tolerance

    return ~done

Is there any chance to make it supported for JIT compilation?

btaba commented 9 months ago

Hi @LyuJ1998, this is a known issue with while_loop: "while_loop is not reverse-mode differentiable because XLA computations require static bounds on memory requirements." https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html

You can change the while_loop to a scan: https://github.com/google/jax/discussions/3850
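As a rough illustration of that change (a toy Newton iteration for a square root, not the MJX CG solver; MAX_ITERS, TOL and the function names are made up for this sketch), one option is to run a fixed number of scan steps and freeze the state once the stopping test passes:

```python
import jax
import jax.numpy as jnp

MAX_ITERS = 20  # hypothetical static bound, playing the role of m.opt.iterations
TOL = 1e-6      # hypothetical tolerance, playing the role of m.opt.tolerance


def newton_update(x, a):
    # One Newton iteration for solving x * x = a.
    return 0.5 * (x + a / x)


def sqrt_while(a):
    # Dynamic stopping: fine for the forward pass, not reverse-mode differentiable.
    a = jnp.asarray(a, dtype=jnp.float32)

    def cond(carry):
        i, x = carry
        return (i < MAX_ITERS) & (jnp.abs(x * x - a) > TOL)

    def body(carry):
        i, x = carry
        return i + 1, newton_update(x, a)

    _, x = jax.lax.while_loop(cond, body, (jnp.int32(0), a))
    return x


def sqrt_scan(a):
    # Static trip count: always runs MAX_ITERS steps, but keeps x unchanged
    # once the stopping test passes, so the result matches the while_loop version.
    a = jnp.asarray(a, dtype=jnp.float32)

    def step(x, _):
        done = jnp.abs(x * x - a) <= TOL
        return jnp.where(done, x, newton_update(x, a)), None

    x, _ = jax.lax.scan(step, a, None, length=MAX_ITERS)
    return x


value_while = sqrt_while(4.0)
value_scan = sqrt_scan(4.0)
grad_scan = jax.jit(jax.grad(sqrt_scan))(4.0)  # d/da sqrt(a) at a = 4 is 0.25
```

The scan version pays for all MAX_ITERS steps on every call, even when the iteration converges early.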

The TracerBoolConversionError occurs because cond_fun(val) is a traced JAX array, but you're using it in a Python while loop, which expects a concrete value. Use a scan or a for loop instead.
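A small stand-alone sketch of that behaviour (halve_until_small is a made-up toy function, unrelated to MJX): the same Python while loop runs eagerly and under jax.grad, but fails once jax.jit traces it.

```python
import jax
import jax.numpy as jnp


def halve_until_small(x):
    # A Python `while` calls bool() on the condition, so it needs a concrete value.
    while x > 1.0:
        x = x / 2.0
    return x


# Eagerly, and under jax.grad without jit, the condition is concrete.
eager_value = halve_until_small(jnp.float32(8.0))
eager_grad = jax.grad(halve_until_small)(8.0)

# Under jit, x is an abstract tracer and bool(x > 1.0) cannot be evaluated.
# TracerBoolConversionError is a subclass of ConcretizationTypeError.
try:
    jax.jit(halve_until_small)(jnp.float32(8.0))
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True
```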

sfd158 commented 9 months ago

To reduce memory usage, there are some in-place operations in mjx.step. In-place operations on intermediate matrices, such as X[0] += Y[0], break the back-propagation path, so jax.grad(mjx.step) doesn't work in MuJoCo 3.0.

erikfrey commented 9 months ago

Hi there,

Yes indeed this is by design, but poorly documented. I'll take this as motivation to add to the documentation.

tl;dr: if you would like to experiment with jax.grad(), please update to MuJoCo 3.0.1 which now includes support for Newton solver in MJX. Newton converges quickly and for many models, a single solver iteration is sufficient. If your XML looks like this:

<option ... solver="Newton" iterations="1" ls_iterations="4">

Then we omit the jax.lax.while_loop and mjx.step is differentiable.

The reason we don't support this for CG is that replacing while with scan harms forward performance in some settings, so we currently accept this tradeoff.

Also please note that we have not investigated whether jax.grad delivers useful gradients in this setting - I would love to hear insights from anyone that tries this.

@sfd158 not quite sure what you mean by inplace operations - jax operations are pure, they do not modify the original. See for example this documentation on jax.numpy.ndarray.at
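For example, a minimal sketch (toy arrays and a made-up function f, not MJX code) of an indexed update written with .at, which returns a new array and stays differentiable:

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Functional equivalent of X[0] += Y[0]: returns a new array, x is untouched.
    x_updated = x.at[0].add(y[0])
    return jnp.sum(x_updated ** 2)


x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0])

# Gradients flow through the indexed update to both inputs.
grad_x, grad_y = jax.grad(f, argnums=(0, 1))(x, y)
```

Here grad_x is 2 * [4, 2] and grad_y is nonzero only in its first entry, since only y[0] is used.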

Andrew-Luo1 commented 5 months ago

Not sure if this is the right place to post this - Re: whether jax.grad delivers useful gradients:

I have been playing with gradients over mjx.step (Newton solver; 1 iteration) for my Master's thesis. Please see an implementation of the Short Horizon Actor Critic (SHAC) algorithm here.

SHAC involves learning control policies using analytical policy gradients; it augments the basic Analytical Policy Gradient (APG) algorithm with several features, such as a value function. I use jax.grad to take the gradient of the loss of an environment rollout with respect to the policy parameters, and these gradients are informative enough to make the algorithm work. This works without contact (inverted pendulum) and with contact (basic hopper).

While the gradients appear to be informative in these simple cases, I can't get quadruped control working; the Jacobian of the MJX step, which is a component of the gradient of the loss wrt the policy parameters, is unstable. There is more on this in the README of the repo. I wonder if different simulation parameters could help here, since this issue appears to have gotten worse from MJX 3.1.1 to MJX 3.1.3.

[Animations: inverted pendulum; hopper]

junqingqiao commented 4 months ago

Hi erikfrey, does setting iterations=1 impact the simulation accuracy? Thanks, Bugman