google-deepmind / mujoco

Multi-Joint dynamics with Contact. A general purpose physics simulator.
https://mujoco.org
Apache License 2.0

`mjx.Data` sparse members change size on the first call to step (versions >= 3.2.3) #2188

Closed yeuhibd closed 3 weeks ago

yeuhibd commented 3 weeks ago

Intro

I am an engineer using MJX for a project at work.

My setup

Python package mujoco-mjx (version >= 3.2.3) on Linux

What's happening? What did you expect?

After the changes to mjx.make_data introduced in version 3.2.3 (commit), some mjx.Data members have a different size after the first call to mjx.step when sparse Jacobians are used. This breaks jax.lax.scan over the jitted step function, because scan requires the carry to keep the same shapes and dtypes on every iteration. The result is the following error (see the reproduction code below).

TypeError: Scanned function carry input and carry output must have equal types (e.g. shapes and dtypes of arrays), but they differ:
  * the input carry component data._qM_sparse has type float32[0] but the corresponding output carry component has type float32[806], so the shapes do not match

  * the input carry component data._qLD_sparse has type float32[0] but the corresponding output carry component has type float32[806], so the shapes do not match

  * the input carry component data._qLDiagInv_sparse has type float32[0] but the corresponding output carry component has type float32[70], so the shapes do not match

Revise the scanned function so that all output types (e.g. shapes and dtypes) match the corresponding input types.

Note that the reproduction code runs without error on versions < 3.2.3.
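
The carry-type rule behind this error can be seen without MJX at all. The sketch below is a toy illustration, not the reproduction code for this issue: `growing_step` and `scan_fails_on_shape_change` are hypothetical names, and the step function simply mimics a member that enters with shape (0,) and leaves with a non-zero size.

```python
import jax
import jax.numpy as jnp


def growing_step(carry, _):
    # Mimics the reported behaviour: the carry comes in with shape (0,)
    # but leaves with shape (3,), so input and output carry types differ.
    del carry
    return jnp.zeros(3, dtype=jnp.float32), None


def scan_fails_on_shape_change():
    """Returns True if jax.lax.scan rejects the shape-changing carry."""
    empty = jnp.zeros(0, dtype=jnp.float32)
    try:
        jax.lax.scan(growing_step, empty, None, length=10)
    except TypeError:
        # scan checks carry types while tracing, before any step runs.
        return True
    return False
```

Since the check happens at trace time, the failure does not depend on the physics at all, only on the shapes of the mjx.Data members before and after one step.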

Steps for reproduction

  1. Run the code below.
  2. Observe error message on versions >= 3.2.3.
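
To see which members change shape without reading them off the traceback, something like the following helper may be useful. It is only a sketch: `changed_leaves` is a hypothetical name, and it assumes the step function returns a pytree with the same structure as its input (which appears to be the case here, since only the shapes differ).

```python
import jax
import jax.numpy as jnp


def changed_leaves(step_fn, carry):
    """Lists (path, shape_before, shape_after) for leaves step_fn reshapes."""
    # eval_shape traces step_fn abstractly, so no simulation is actually run.
    out = jax.eval_shape(step_fn, carry)
    before = jax.tree_util.tree_leaves_with_path(carry)
    after = jax.tree_util.tree_leaves(out)
    return [
        (jax.tree_util.keystr(path), tuple(jnp.shape(x)), tuple(y.shape))
        for (path, x), y in zip(before, after)
        if tuple(jnp.shape(x)) != tuple(y.shape)
    ]


def toy_step(carry):
    # Hypothetical stand-in for a step: "a" grows from size 0, "b" keeps its shape.
    return {"a": jnp.zeros(3, dtype=jnp.float32), "b": carry["b"] + 1.0}


toy_carry = {"a": jnp.zeros(0, dtype=jnp.float32), "b": jnp.zeros(2, dtype=jnp.float32)}
report = changed_leaves(toy_step, toy_carry)
```

With the model from this issue, the call would presumably be `changed_leaves(lambda d: mjx.step(mjx_model, d), mjx_data)`, which should list the same members as the error message.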

Minimal model for reproduction

The reproducing code uses a simple inline model.

Code required for reproduction

import mujoco
from mujoco import mjx
import jax

xml = """
<mujoco>
<worldbody>
    <body name="body1" pos="0 0 0">
        <freejoint name="base1"/>
        <inertial pos="0 0 0" mass="1" diaginertia="1 1 1"/>
    </body>
</worldbody>
</mujoco>
"""

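# Build the MuJoCo model and switch it to sparse Jacobians.
mj_model = mujoco.MjModel.from_xml_string(xml)
mj_model.opt.jacobian = mjx.JacobianType.SPARSE
mj_data = mujoco.MjData(mj_model)
# Transfer model and data to MJX.
mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)
# scan expects f(carry, x) -> (carry, y); the data is the carry.
step = jax.jit(lambda data, _: (mjx.step(mjx_model, data), None))
# scan returns (final_carry, stacked_ys), so unpack the pair.
mjx_data, _ = jax.lax.scan(f=step, init=mjx_data, xs=None, length=10)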
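
One possible workaround, which I have not confirmed to be the recommended fix, is to take the first step outside of scan so that the carry already has its post-step shapes. The sketch below shows the idea on a toy step function; `scan_after_warmup` and `toy_step` are hypothetical names, and the approach assumes shapes stay fixed after the first step.

```python
import jax
import jax.numpy as jnp


def scan_after_warmup(step_fn, init, length):
    """Applies step_fn `length` times (length >= 1), first step outside scan."""
    # One call outside scan lets any lazily sized members reach their final shape.
    warmed = step_fn(init)
    # The remaining steps now see a carry whose type no longer changes.
    final, _ = jax.lax.scan(
        lambda carry, _: (step_fn(carry), None), warmed, None, length=length - 1
    )
    return final


def toy_step(x):
    # Hypothetical stand-in for mjx.step: grows an empty array once,
    # then keeps its shape on every later call.
    if x.shape[0] == 0:
        return jnp.zeros(3, dtype=jnp.float32)
    return x + 1.0


final = scan_after_warmup(toy_step, jnp.zeros(0, dtype=jnp.float32), length=10)
```

With the reproduction code above, the equivalent call would presumably be `scan_after_warmup(lambda d: mjx.step(mjx_model, d), mjx_data, 10)`.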
