Intro
I am an engineer using MJX for a project at work.
My setup
Python, mujoco-mjx (version >= 3.2.3), on Linux.
What's happening? What did you expect?
After the changes to mjx.make_data introduced in version 3.2.3 (commit), the size of some mjx.Data members changes after the first call to mjx.step when sparse Jacobians are used. As the error below shows, the sparse fields are empty (size 0) in the data returned by mjx.make_data but have their full size in the data returned by mjx.step. This breaks jax.lax.scan over the jitted step function, because scan requires the carry to have the same shapes and dtypes on input and output. Running the reproducing code below fails with:
TypeError: Scanned function carry input and carry output must have equal types (e.g. shapes and dtypes of arrays), but they differ:
* the input carry component data._qM_sparse has type float32[0] but the corresponding output carry component has type float32[806], so the shapes do not match
* the input carry component data._qLD_sparse has type float32[0] but the corresponding output carry component has type float32[806], so the shapes do not match
* the input carry component data._qLDiagInv_sparse has type float32[0] but the corresponding output carry component has type float32[70], so the shapes do not match
Revise the scanned function so that all output types (e.g. shapes and dtypes) match the corresponding input types.
I expected the mjx.Data members to keep the same shapes across calls to mjx.step, so that the step function can be scanned. The same reproducing code runs without error on versions < 3.2.3.
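The failure can be reproduced without MuJoCo. The sketch below is a pure-JAX toy, not the MJX code: `toy_step` and the buffer size `N` are hypothetical stand-ins for the jitted mjx.step and for sparse fields such as `_qM_sparse`, which start out empty and come back full-size. It shows that jax.lax.scan rejects such a carry, uses jax.eval_shape to list the carry leaves whose shapes change (without running the step), and tries one possible workaround, which is my assumption and not a confirmed fix: take a single step outside the scan so the initial carry already has the post-step shapes.

```python
import jax
import jax.numpy as jnp

N = 4  # hypothetical full size of the "sparse" buffer after one step


def toy_step(carry):
    """Hypothetical stand-in for a jitted mjx.step.

    Whatever the input buffer shape, the output buffer always has N entries,
    mimicking fields that are empty after make_data but full-size after step.
    """
    qpos, buf = carry
    new_buf = jnp.full((N,), qpos, dtype=jnp.float32)
    return qpos + 1.0, new_buf


def scan_body(carry, _):
    carry = toy_step(carry)
    return carry, carry[0]  # record qpos at every step


# Like freshly created data: the buffer starts out empty (shape (0,)).
init = (jnp.float32(0.0), jnp.zeros((0,), dtype=jnp.float32))

# 1) Scanning from the fresh carry fails with the carry-type TypeError.
scan_error = None
try:
    jax.lax.scan(scan_body, init, None, length=3)
except TypeError as err:
    scan_error = err
    print("scan failed:", str(err).splitlines()[0])

# 2) List carry leaves whose shape changes; eval_shape only traces the step.
in_flat, _ = jax.tree_util.tree_flatten_with_path(init)
out_flat, _ = jax.tree_util.tree_flatten_with_path(jax.eval_shape(toy_step, init))
changed = [
    jax.tree_util.keystr(path)
    for (path, a), (_, b) in zip(in_flat, out_flat)
    if a.shape != b.shape
]
print("leaves with changing shape:", changed)

# 3) Workaround sketch: one warm-up step outside scan fixes the carry shapes.
warm = toy_step(init)
final_carry, qpos_traj = jax.lax.scan(scan_body, warm, None, length=3)
print("qpos trajectory:", qpos_traj)
```

Applied to the real carry, the same path-based comparison should point at fields such as `_qM_sparse`. Note that the warm-up call is a real step, so it advances the state once before the scan starts; I would treat it as a stopgap until the shapes from mjx.make_data match those from mjx.step again.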
Steps for reproduction
Minimal model for reproduction
The reproducing code uses a simple inline model.
Code required for reproduction
Confirmations