google-deepmind / mujoco

Multi-Joint dynamics with Contact. A general purpose physics simulator.
https://mujoco.org
Apache License 2.0

`mjx.device_get_into` not working with `jax.jit` #1256

Closed: Kallinteris-Andreas closed this issue 10 months ago

Kallinteris-Andreas commented 10 months ago

Hi,

I'm a maintainer of Gymnasium & Gymnasium-Robotics, and I'm trying to use MuJoCo-MJX for "prototyping MJX-based RL environments in Gymnasium".

Here is the model tested: Gymnasium/HalfCheetah (though it should not be relevant to this question).

I am trying to use mjx.device_get_into inside a function compiled with jax.jit:

mjx.device_get_into(self.data, self.mjx_data)

and I get this error:

...
  File "/home/master-andreas/mjx/Gymnasium-kalli/gymnasium/envs/mujoco/mujoco_env.py", line 495, in render
    mjx.device_get_into(self.data, self.mjx_data) 
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/master-andreas/mjx/mjx_env/lib/python3.11/site-packages/mujoco/mjx/_src/device.py", line 284, in device_get_into
    setattr(result, f.name, field_value)
TypeError: (): incompatible function arguments. The following argument types are supported:
    1. (arg0: mujoco._structs.MjData, arg1: float) -> None

Invoked with: <mujoco._structs.MjData object at 0x7f0bd9c9d830>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>

Note: I do not get this error without jax.jit.

Is this something that is expected to happen with jax.jit and mjx.device_get_into?

@erikfrey, is this an issue that will be fixed by mjx.get_data()?

Thanks!

btaba commented 10 months ago

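Hi @Kallinteris-Andreas, this is entirely expected: device_get_into doesn't work under jit because mujoco.MjData is not a pytree. Inside a jitted function the fields of mjx_data are abstract tracers rather than concrete arrays, so they cannot be written into the C-backed MjData (hence the "incompatible function arguments" error). You can call device_get_into after your jitted function calls return.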