I'm a maintainer of Gymnasium & Gymnasium-Robotics, and I'm trying to use MuJoCo-MJX for "prototyping MJX-based RL environments in Gymnasium".
The model tested is Gymnasium/HalfCheetah (though it should not be relevant to this question).
I am trying to use `mjx.device_get_into` with `jax.jit`:

```python
mjx.device_get_into(self.data, self.mjx_data)
```
and I get this error:

```
...
  File "/home/master-andreas/mjx/Gymnasium-kalli/gymnasium/envs/mujoco/mujoco_env.py", line 495, in render
    mjx.device_get_into(self.data, self.mjx_data)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/master-andreas/mjx/mjx_env/lib/python3.11/site-packages/mujoco/mjx/_src/device.py", line 284, in device_get_into
    setattr(result, f.name, field_value)
TypeError: (): incompatible function arguments. The following argument types are supported:
    1. (arg0: mujoco._structs.MjData, arg1: float) -> None
Invoked with: <mujoco._structs.MjData object at 0x7f0bd9c9d830>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
```
Note: I do not get this error without `jax.jit`.
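The difference the note points at can be reproduced without MuJoCo. In the minimal sketch below, `HostRecord` is a hypothetical stand-in for `mujoco.MjData` (a host-side object that is not a pytree and whose fields need concrete floats), and `advance` is a hypothetical one-line "step". Outside `jax.jit` the result is a concrete array, so converting it to a Python `float` works; inside `jax.jit` the value is an abstract tracer (the `Traced<ShapedArray(float32[], ...)>` in the traceback), so the conversion fails while the function is being traced. JAX reports this case as `jax.errors.ConcretizationTypeError`; the pybind11 setter on `MjData` runs into much the same problem and surfaces it as the `TypeError` above.

```python
import jax
import jax.numpy as jnp


class HostRecord:
    """Hypothetical stand-in for mujoco.MjData: a host-side object that is
    not a JAX pytree and whose fields only accept concrete Python floats."""

    def __init__(self) -> None:
        self.time = 0.0

    def set_time(self, value) -> None:
        # float() needs a concrete value, much like the pybind11 setter on
        # MjData.time expects a real float rather than a JAX tracer.
        self.time = float(value)


def advance(t):
    # Hypothetical "physics step": just advance time by a fixed dt.
    return t + 0.01


record = HostRecord()

# Outside jit: advance() returns a concrete jax.Array, so float() succeeds.
record.set_time(advance(jnp.float32(0.0)))


@jax.jit
def advance_and_copy(t):
    new_t = advance(t)
    # Inside jit, new_t is an abstract tracer: this raises during tracing.
    record.set_time(new_t)
    return new_t


try:
    advance_and_copy(jnp.float32(0.0))
    failed_inside_jit = False
except jax.errors.ConcretizationTypeError:
    failed_inside_jit = True

print(record.time, failed_inside_jit)
```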
Is this something that is expected to happen with `jax.jit` and `mjx.device_get_into`?

@erikfrey, is this an issue that will be fixed with `mjx.get_data()`?
Hi @Kallinteris-Andreas, this is entirely expected: `device_get_into` doesn't work with `jit` because `mujoco.MjData` is not a pytree. You can call `device_get_into` after your jitted function calls return.