google-deepmind / mujoco

Multi-Joint dynamics with Contact. A general purpose physics simulator.
https://mujoco.org
Apache License 2.0

[`MJX`] `mj_getState` and `mj_setState` equivalents for `mjx` #1283

Closed. Kallinteris-Andreas closed this issue 3 months ago.

Kallinteris-Andreas commented 11 months ago

Hi,

I'm a maintainer of Gymnasium & the project manager of Gymnasium-Robotics, and I'm trying to use MuJoCo-MJX for "prototyping MJX-based RL environments in Gymnasium, Gymnasium-Robotics, Metaworld, MO-Gymnasium".

The Python MuJoCo API has `mj_getState` and `mj_setState` (https://mujoco.readthedocs.io/en/3.1.0/APIreference/APIfunctions.html#mj-getstate). Example usage:

```python
import mujoco
import numpy as np

# `env` is a Gymnasium MuJoCo environment; `env.unwrapped.model` / `.data` are its MjModel / MjData.
state = np.empty(mujoco.mj_stateSize(env.unwrapped.model, mujoco.mjtState.mjSTATE_PHYSICS))
mujoco.mj_getState(env.unwrapped.model, env.unwrapped.data, state, mujoco.mjtState.mjSTATE_PHYSICS)

mujoco.mj_setState(env.unwrapped.model, env.unwrapped.data, state, mujoco.mjtState.mjSTATE_PHYSICS)
```

MJX does not have an equivalent, but it is easy to write one, for example:

```python
import jax.numpy as jnp
from mujoco import mjx

# TODO unit test these
def mjx_get_physics_state(mjx_data: mjx.Data) -> jnp.ndarray:
    """Get the physics state of `mjx_data`, similar to `mujoco.mj_getState` with `mjSTATE_PHYSICS`."""
    # Layout: [qpos, qvel, act].
    return jnp.concatenate([mjx_data.qpos, mjx_data.qvel, mjx_data.act])


def mjx_set_physics_state(mjx_data: mjx.Data, mjx_physics_state: jnp.ndarray) -> mjx.Data:
    """Set the physics state in `mjx_data` (inverse of `mjx_get_physics_state`)."""
    # Sizes come from static array shapes, so this slicing also works under jax.jit.
    qpos_end_index = mjx_data.qpos.size
    qvel_end_index = qpos_end_index + mjx_data.qvel.size

    qpos = mjx_physics_state[:qpos_end_index]
    qvel = mjx_physics_state[qpos_end_index:qvel_end_index]
    act = mjx_physics_state[qvel_end_index:]
    assert qpos.size == mjx_data.qpos.size
    assert qvel.size == mjx_data.qvel.size
    assert act.size == mjx_data.act.size

    # mjx.Data is immutable; `replace` returns an updated copy.
    return mjx_data.replace(qpos=qpos, qvel=qvel, act=act)
```

Is there a plan to add `mj_getState` and `mj_setState` equivalents to MJX, or are users expected to write their own?

Thanks!

btaba commented 11 months ago

The recommended way to do this right now is to call `mjx.get_data` (or `mjx.put_data`) and use MuJoCo's get/set state API. Does that not fit your use case? Adding Erik in case he has more thoughts.

Kallinteris-Andreas commented 11 months ago

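The problem with your suggestion is that it is not JIT-compatible: it cannot run inside `jax.jit`, because it copies the data to the host and calls the C MuJoCo API.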

```python
import mujoco
import numpy as np
from mujoco import mjx

def mjx_get_physics_state_put_version(self, mjx_data: mjx.Data) -> np.ndarray:
    """Version based on @btaba's suggestion."""
    data = mujoco.MjData(self.model)
    mjx.device_get_into(data, mjx_data)  # copy the device data into a host MjData
    #data = mjx.get_data(self.mjx_model, mjx_data)   # TODO figure out how to use get_data instead
    state = np.empty(mujoco.mj_stateSize(self.model, mujoco.mjtState.mjSTATE_PHYSICS))
    mujoco.mj_getState(self.model, data, state, mujoco.mjtState.mjSTATE_PHYSICS)

    return state
```
btaba commented 10 months ago

Hi @Kallinteris-Andreas

You should call `mjx_get_physics_state_put_version` outside of `jax.jit`. Once all the computations are done on device (in MJX-land), transfer the data back to the host using `device_get_into` or `get_data`. Does that make sense?