Kallinteris-Andreas closed this issue 3 months ago.
The recommended way to do this now is to call `mjx.get_data` or `mjx.put_data` and use the get/set State API from MuJoCo. Does that not fit the use case you have?
Adding Erik if he has more thoughts
The problem with your suggestion is that it is not JITable.
```python
import mujoco
import numpy as np
from mujoco import mjx


def mjx_get_physics_state_put_version(self, mjx_data: mjx.Data) -> np.ndarray:
    """Version based on @btaba suggestion."""
    # Copy the device-side MJX data into a fresh host-side MjData.
    data = mujoco.MjData(self.model)
    mjx.device_get_into(data, mjx_data)
    #data = mjx.get_data(self.mjx_model, mjx_data) # TODO figure out how to use get_data instead
    # Flatten qpos, qvel and act into a single host (NumPy) array.
    state = np.empty(mujoco.mj_stateSize(self.model, mujoco.mjtState.mjSTATE_PHYSICS))
    mujoco.mj_getState(self.model, data, state, spec=mujoco.mjtState.mjSTATE_PHYSICS)
    return state
```
Hi @Kallinteris-Andreas
You should call `mjx_get_physics_state_put_version` outside of `jax.jit`. Once all the computations are done on device (in MJX-land), only then should you transfer the data back to the host using `device_get_into` or `get_data`. Does that make sense?
Hi,
I'm a maintainer of Gymnasium and the project manager of Gymnasium-Robotics, and I'm trying to use MuJoCo-MJX for "prototyping MJX-based RL environments in Gymnasium, Gymnasium-Robotics, Metaworld, MO-Gymnasium".

The Python `mujoco` API has `mj_getState` and `mj_setState` (https://mujoco.readthedocs.io/en/3.1.0/APIreference/APIfunctions.html#mj-getstate). Example usage:

But MJX does not have an equivalent, although one can easily write one's own. Example:

Is there a plan to add `mj_getState` and `mj_setState` functions to `mjx`, or is the user expected to write their own?

Thanks!