google-deepmind / mujoco

Multi-Joint dynamics with Contact. A general purpose physics simulator.
https://mujoco.org
Apache License 2.0

mjx-accelerated gymnasium-like environment primitives (env.reset(), env.step()) without brax #1787

Closed 7oponaut closed 4 months ago

7oponaut commented 4 months ago

Hello,

I am currently working on an RL project that involves training models from mujoco_menagerie to perform certain tasks.

I already have working code for the whole process, written in PyTorch. I would like to accelerate the MuJoCo simulation with MJX, so I am planning to port my code to JAX.

I want to implement a gymnasium-like environment interface such that the batch of RL environments can be accelerated using my GPU.

I am aware of the tutorial notebook that covers mjx acceleration. This notebook uses brax for interacting with the simulator, along with brax's own PPO implementation. As I have said, I have my own implementation, which I want to stick with. I have been browsing the mujoco and brax GitHub repositories to figure out how to implement the primitives I need for an accelerated training loop.

I need a way to:

1) reset a batch of mujoco simulations, with the option to control which instances to reset (e.g. with masking)
2) randomly modify the initial state configuration per simulator instance to diversify training data
3) convey batched agent actions to the simulator instances

I would like to perform these actions on the accelerated mjx objects.

The mjx tutorial uses brax, which abstracts away implementation details. I would like your assistance with figuring out how to implement the primitives I need without brax, since I would like to understand how to work with mujoco/mjx/jax directly, and I want flexibility in how my code is structured.

1) For simulator resets, I suspect that it might be enough to reset mjx_data.qpos and mjx_data.qvel to their initial values, but mujoco.mj_resetData() does a bunch of other things based on the code in src/engine/engine_io.c

How should I reset mjx_data?

2) Afaik modifying the freejoint parameters in mjx_data.qpos after reset should be sufficient to randomize the robot configuration.

Another issue: I use hfields to randomize terrain for each environment instance. In my current implementation I have a unique mj_model for each environment instance, and I can modify the mj_model.hfield_data of each environment to a unique value (of identical shape) during reset.

However, in the mjx tutorial, mjx.step() takes as input a single mjx_model and a batch of mjx_data. Since mjx_model contains the hfield_data array, there seems to be no way to set a different hfield for each environment instance. Is there a way around this?

3) Is setting the mjx_data.ctrl array the canonical way of sending control signals when using mjx?


It would be nice to have a brax-free mjx tutorial that clarifies these questions. If this already exists, please do refer me to it.

7oponaut commented 4 months ago

I did more research, reading the mujoco docs and the brax source code.

The relevant bits for basic interaction with the mjx objects seem to be at brax/mjx/pipeline.py. The class System is subclassed from mjx.Model and the class State is subclassed from mjx.Data. Looks like setting mjx_data.ctrl is the way to set the actuators like I mentioned.

I came up with some code that I think addresses my questions, including the one about how to simulate environment instances with different height maps. I am using jax for the first time so please excuse my French. In short, I batch both mjx_model and mjx_data and I use jax.vmap() to vectorize everything.
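
As a side note on the batching approach: repeating every model field along a batch axis is not the only option. A minimal sketch, assuming a toy dict pytree stands in for mjx.Model (the field names and the `toy_step` dynamics are illustrative only, not the MJX API), shows how `jax.vmap` with an `in_axes` spec can batch just the per-environment field and share the rest:

```python
import jax
import jax.numpy as jnp

batch_size = 4

# Toy stand-in for mjx.Model: one per-environment field, one shared field.
model = {
    "hfield_data": jnp.zeros((batch_size, 16)),  # batched: one terrain per env
    "gravity": jnp.array([0.0, 0.0, -9.81]),     # shared: no batch axis
}
# Toy stand-in for a batch of mjx.Data.
data = {"qpos": jnp.ones((batch_size, 3))}


def toy_step(model, data):
    # Placeholder dynamics, only here to exercise the batching.
    qpos = data["qpos"] + 0.01 * model["gravity"] + model["hfield_data"].mean()
    return {"qpos": qpos}


# in_axes mirrors the model pytree: 0 = mapped over the batch, None = shared.
model_axes = {"hfield_data": 0, "gravity": None}
batched_step = jax.jit(jax.vmap(toy_step, in_axes=(model_axes, 0)))

new_data = batched_step(model, data)
print(new_data["qpos"].shape)
```

With a real mjx.Model the axis spec would presumably be built by mapping every leaf to None and then marking hfield_data as 0; that part is untested here and should be treated as an assumption.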

The code I am posting here is set up for testing on mujoco_menagerie/google_barkour_vb/scene_hfield_mjx.xml to make it reproducible, although I am working with a slightly different setup. You can save the code in a new file in the mujoco_menagerie repo folder and run it from there.

jax.jit() seems to be crucial for acceptable runtimes: mjx.step() is very slow without it, even on subsequent runs.

I am measuring ~1 ms runtimes on the second run for reset() and step() on my machine (RTX 4090).

import time

import jax
import jax.numpy as jnp
import jax.random as jr
import mujoco
import mujoco.mjx as mjx

def create_hfield(rng, shape) -> jnp.ndarray:
    # dummy implementation
    return jr.uniform(rng, shape, minval=0.0, maxval=1.0)

def create(xml_path: str, batch_size: int):
    mj_model = mujoco.MjModel.from_xml_path(xml_path)
    mjx_model = mjx.put_model(mj_model)
    mjx_model_batch = jax.tree.map(
        lambda x: x[None].repeat(batch_size, axis=0), mjx_model
    )
    mjx_data_batch = jax.vmap(mjx.make_data)(mjx_model_batch)
    return mjx_model, mjx_model_batch, mjx_data_batch

def reset_(key, mjx_model_batch, mjx_data_batch, reset_mask: jnp.ndarray):
    key = jr.split(key, reset_mask.shape)
    mjx_model_batch = jax.vmap(
        lambda key, reset, mjx_model: (
            jax.lax.cond(
                reset,
                lambda _: mjx_model.replace(
                    hfield_data=create_hfield(key, mjx_model.hfield_data.shape)
                ),
                lambda _: mjx_model,
                operand=None,
            )
        )
    )(key, reset_mask, mjx_model_batch)
    del key

    mjx_data_batch = jax.vmap(
        lambda reset, mjx_model, mjx_data: (
            jax.lax.cond(
                reset,
                lambda _: mjx.make_data(mjx_model),
                lambda _: mjx_data,
                operand=None,
            )
        )
    )(reset_mask, mjx_model_batch, mjx_data_batch)

    return mjx_model_batch, mjx_data_batch

def step_(mjx_model_batch, mjx_data_batch, action):
    mjx_data_batch = mjx_data_batch.replace(ctrl=action)
    mjx_data_batch = jax.vmap(mjx.step)(mjx_model_batch, mjx_data_batch)
    return mjx_data_batch

def main() -> None:
    key = jr.key(42)

    batch_size = 4096

    reset = jax.jit(reset_)
    step = jax.jit(step_)

    for idx in range(2):
        print(f"run {idx + 1}")

        t = time.perf_counter()
        mjx_model, mjx_model_batch, mjx_data_batch = create(
            "google_barkour_vb/scene_hfield_mjx.xml",
            batch_size,
        )
        print("create", time.perf_counter() - t)

        key, subkey = jr.split(key)
        t = time.perf_counter()
        mjx_model_batch, mjx_data_batch = reset(
            subkey,  # use the fresh subkey; `key` is split again later
            mjx_model_batch,
            mjx_data_batch,
            jnp.ones((batch_size,), dtype=jnp.bool),
        )
        print("reset all", time.perf_counter() - t)
        del subkey

        ctrl_range = mjx_model.actuator_ctrlrange
        key, subkey = jr.split(key)
        action = jr.uniform(
            subkey,
            (batch_size, ctrl_range.shape[0]),
            minval=ctrl_range[:, 0],
            maxval=ctrl_range[:, 1],
        )
        del subkey

        t = time.perf_counter()
        mjx_data_batch = step(mjx_model_batch, mjx_data_batch, action)
        print("step", time.perf_counter() - t)

        reset_mask = jnp.zeros((batch_size,), dtype=jnp.bool)
        reset_mask = reset_mask.at[: batch_size // 2].set(True)

        key, subkey = jr.split(key)
        t = time.perf_counter()
        mjx_model_batch, mjx_data_batch = reset(
            subkey, mjx_model_batch, mjx_data_batch, reset_mask
        )
        print("reset some", time.perf_counter() - t)
        del subkey

if __name__ == "__main__":
    main()

Output:

run 1
create 2.4136174280001796
reset all 0.368587808999564
step 5.961428407999847
reset some 0.0013962110001557448
run 2
create 0.10277210800040848
reset all 0.0010273100001541025
step 0.0009445460000279127
reset some 0.0007678189999751339
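
A caveat on timings like these: JAX dispatches work asynchronously, so stopping the clock right after a jitted call returns may measure only the dispatch rather than the computation. A minimal sketch of a timing helper that waits for the result, assuming a toy jitted function (`toy_step` and `timed` are hypothetical names) in place of the batched mjx step:

```python
import time

import jax
import jax.numpy as jnp


def timed(fn, *args):
    """Run fn(*args) and wait for the device result before stopping the clock."""
    t0 = time.perf_counter()
    out = jax.block_until_ready(fn(*args))
    return out, time.perf_counter() - t0


@jax.jit
def toy_step(x):
    # Stand-in for a jitted batched step function.
    return jnp.tanh(x @ x.T).sum(axis=1)


x = jnp.ones((512, 512))
_, first = timed(toy_step, x)     # first call includes compilation
out, second = timed(toy_step, x)  # later calls reuse the compiled function
print(f"first call {first:.4f}s, second call {second:.4f}s")
```

The same helper could wrap reset() and step() in the script above; whether the ~1 ms figures change when measured this way has not been checked here.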

7oponaut commented 4 months ago

I iterated on this and came up with a solution that I think is more viable. The code has gotten a bit long for GitHub comments, but I will post it here for completeness.

In this version, I define init_mjx_single() and step_mjx_single() that run on single mjx.Model and mjx.Data instances. This seems more in line with jax examples I have seen. I can use these functions to take care of environment-specific business:

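import time

import jax
import jax.numpy as jnp
import jax.random as jr
import mujoco
import mujoco.mjx as mjx
from jax import Array

def _get_obs(mjx_data: mjx.Data) -> Array: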
    return jnp.concatenate(
        [
            mjx_data.qpos,
            mjx_data.qvel,
            mjx_data.cinert.ravel(),
            mjx_data.cvel.ravel(),
        ]
    )

def init_mjx_single(
    key: Array, mjx_model: mjx.Model
) -> tuple[mjx.Model, mjx.Data, Array]:
    hfield_data = jr.uniform(key, mjx_model.hfield_data.shape, minval=0.0, maxval=1.0)
    mjx_model = mjx_model.replace(hfield_data=hfield_data)

    mjx_data = mjx.make_data(mjx_model)
    mjx_data = mjx.forward(mjx_model, mjx_data)
    obs = _get_obs(mjx_data)

    return mjx_model, mjx_data, obs

def step_mjx_single(
    mjx_model: mjx.Model, mjx_data: mjx.Data, action: Array
) -> tuple[mjx.Data, Array, float, bool, bool]:
    action = action.clip(min=-1.0, max=1.0)
    ctrl_range = mjx_model.actuator_ctrlrange
    ctrl_mid = ctrl_range.mean(axis=1)
    ctrl_half = ctrl_range[:, 1] - ctrl_mid
    action = ctrl_mid + ctrl_half * action

    mjx_data = mjx_data.replace(ctrl=action)
    mjx_data = mjx.step(mjx_model, mjx_data)

    obs = _get_obs(mjx_data)
    reward = 0.0
    terminated = False
    truncated = False

    return mjx_data, obs, reward, terminated, truncated
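
The action rescaling at the top of step_mjx_single() can be checked on its own. A small sketch with made-up control ranges (`scale_action` is a hypothetical helper that repeats the same arithmetic):

```python
import jax.numpy as jnp


def scale_action(action, ctrl_range):
    """Map actions in [-1, 1] onto per-actuator [low, high] control ranges."""
    action = jnp.clip(action, -1.0, 1.0)
    ctrl_mid = ctrl_range.mean(axis=1)        # centre of each range
    ctrl_half = ctrl_range[:, 1] - ctrl_mid   # half-width of each range
    return ctrl_mid + ctrl_half * action


# Made-up ranges for two actuators: [-0.5, 0.5] and [0.0, 2.0].
ctrl_range = jnp.array([[-0.5, 0.5], [0.0, 2.0]])

low = scale_action(jnp.array([-1.0, -1.0]), ctrl_range)   # lower bounds
high = scale_action(jnp.array([1.0, 1.0]), ctrl_range)    # upper bounds
mixed = scale_action(jnp.array([0.0, 5.0]), ctrl_range)   # 5.0 is clipped to 1.0
print(low, high, mixed)
```

An action of -1 lands on the lower bound, +1 on the upper bound, and out-of-range actions are clipped before scaling.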

There are three more functions.

One that creates the mjx.Model batched data structure:

def create(xml_path: str, batch_size: int) -> mjx.Model:
    mj_model = mujoco.MjModel.from_xml_path(xml_path)
    mjx_model = mjx.put_model(mj_model)
    mjx_model_batch = jax.tree.map(
        lambda x: x[None].repeat(batch_size, axis=0), mjx_model
    )
    return mjx_model_batch

A batched reset() function constructed from init_mjx_single(). No need for masked resets here after all: I take care of that in the step() function.

def make_reset(init_mjx):
    def reset(key: Array, mjx_model: mjx.Model):
        batch_size = mjx_model.hfield_data.shape[0]
        key = jr.split(key, batch_size)
        mjx_model, mjx_data, obs = jax.vmap(init_mjx)(key, mjx_model)
        return mjx_model, mjx_data, obs

    return reset
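
The key-splitting pattern used in reset() is what gives each instance its own randomization. A minimal sketch, assuming a hypothetical `init_single` that perturbs a nominal qpos (not the hfield sampling above), with one key per environment:

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def init_single(key, nominal_qpos):
    # Hypothetical randomization: small uniform perturbation of the nominal pose.
    noise = jr.uniform(key, nominal_qpos.shape, minval=-0.1, maxval=0.1)
    return nominal_qpos + noise


batch_size = 8
nominal_qpos = jnp.zeros((7,))

# One independent key per environment instance.
keys = jr.split(jr.key(0), batch_size)

# Map over the keys only; the nominal pose is shared by all instances.
qpos_batch = jax.vmap(init_single, in_axes=(0, None))(keys, nominal_qpos)
print(qpos_batch.shape)
```

Passing the same unsplit key to every instance would instead give every environment identical noise.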

A batched step() function constructed from init_mjx_single() and step_mjx_single(). This takes care of resets as well:

def make_step(init_mjx, step_mjx):
    def init_mjx_(key: Array, mjx_model: mjx.Model):
        mjx_model, mjx_data, obs = init_mjx(key, mjx_model)
        return mjx_model, mjx_data, obs, 0.0, False, False

    def step_mjx_(mjx_model: mjx.Model, mjx_data: mjx.Data, action: Array):
        mjx_data, obs, reward, terminated, truncated = step_mjx(
            mjx_model, mjx_data, action
        )
        return mjx_model, mjx_data, obs, reward, terminated, truncated

    def step(
        key: Array,
        mjx_model: mjx.Model,
        mjx_data: mjx.Data,
        terminated: Array,
        truncated: Array,
        action: Array,
    ):
        done = terminated | truncated
        key = jr.split(key, done.shape[0])
        mjx_model, mjx_data, obs, reward, terminated, truncated = jax.vmap(
            lambda key, mjx_model, mjx_data, done, action: (
                jax.lax.cond(
                    done,
                    lambda _: init_mjx_(key, mjx_model),
                    lambda _: step_mjx_(mjx_model, mjx_data, action),
                    operand=None,
                )
            )
        )(key, mjx_model, mjx_data, done, action)
        return mjx_model, mjx_data, obs, reward, terminated, truncated

    return step

Code for testing run times:

def main() -> None:
    key = jr.key(42)

    batch_size = 4096

    reset = jax.jit(make_reset(init_mjx_single))
    step = jax.jit(make_step(init_mjx_single, step_mjx_single))

    for idx in range(2):
        print(f"run {idx + 1}")

        t = time.perf_counter()
        mjx_model_batch = create(
            "mujoco_menagerie/google_barkour_vb/scene_hfield_mjx.xml",
            batch_size,
        )
        print("create", time.perf_counter() - t)

        ctrl_range = mjx_model_batch.actuator_ctrlrange[0]

        key, subkey = jr.split(key)
        t = time.perf_counter()
        mjx_model_batch, mjx_data_batch, obs_batch = reset(subkey, mjx_model_batch)
        print("reset", time.perf_counter() - t)
        del subkey

        key, subkey = jr.split(key)
        action = jr.uniform(
            subkey,
            (batch_size, ctrl_range.shape[0]),
            minval=-1.0,
            maxval=1.0,
        )
        del subkey

        terminated = jnp.zeros((batch_size,), dtype=jnp.bool)
        truncated = jnp.zeros((batch_size,), dtype=jnp.bool)

        key, subkey = jr.split(key)
        t = time.perf_counter()
        mjx_model_batch, mjx_data_batch, obs, reward, terminated, truncated = step(
            subkey, mjx_model_batch, mjx_data_batch, terminated, truncated, action
        )
        print("step", time.perf_counter() - t)
        del subkey

        terminated = terminated.at[: batch_size // 2].set(True)

        key, subkey = jr.split(key)
        action = jr.uniform(
            subkey,
            (batch_size, ctrl_range.shape[0]),
            minval=-1.0,
            maxval=1.0,
        )
        del subkey

        key, subkey = jr.split(key)
        t = time.perf_counter()
        mjx_model_batch, mjx_data_batch, obs, reward, terminated, truncated = step(
            subkey, mjx_model_batch, mjx_data_batch, terminated, truncated, action
        )
        print("step/reset", time.perf_counter() - t)
        del subkey

if __name__ == "__main__":
    main()

Output:

run 1
create 1.4571946640062379
reset 6.0006099570018705
step 11.913964218998444
step/reset 0.0019596930069383234
run 2
create 0.09385564400872681
reset 0.002368453991948627
step 0.0017020349914673716
step/reset 0.0014810299908276647

JeyRunner commented 4 months ago

Looks interesting :) Did you compare the performance of the step function with and without the jax.lax.cond() around init_mjx_? I am not sure whether cond guarantees that only the selected branch is executed (even though only the selected branch's value is used as the result of the expression). If both branches run, init_mjx_ would add overhead at every step.

The XLA documentation does describe this as executing only the selected branch (https://openxla.org/xla/operation_semantics#conditional). Nevertheless, it would be interesting to test this, since in brax the environment is not really reset inside the step function; instead, a cached first state is applied (for performance reasons, I guess).
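
One way to look into this without a benchmark is to inspect the traced program. The JAX documentation for jax.lax.cond notes that under jax.vmap with a batched predicate, cond is converted to a select, which means both branches are computed for every instance. A small sketch with toy branches (`reset_or_step` is a hypothetical stand-in for the init/step pair):

```python
import jax
import jax.numpy as jnp


def reset_or_step(done, x):
    # Toy stand-ins: "reset" returns zeros, "step" does some arithmetic.
    return jax.lax.cond(
        done,
        lambda x: jnp.zeros_like(x),
        lambda x: jnp.sin(x) + 1.0,
        x,
    )


done = jnp.array([True, False])
x = jnp.ones((2, 3))

# Unbatched trace: a real conditional.
single = jax.make_jaxpr(reset_or_step)(done[0], x[0])
# Batched trace: the predicate differs per instance.
batched = jax.make_jaxpr(jax.vmap(reset_or_step))(done, x)

uses_cond_single = "cond" in str(single)
uses_select_batched = "select_n" in str(batched)
out = jax.vmap(reset_or_step)(done, x)
print(uses_cond_single, uses_select_batched)
```

If that carries over to the mjx case, the vmapped step above would pay for init_mjx_ on every call, which is consistent with brax applying a cached first state instead. The actual cost for the posted code has not been measured here.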

erikfrey commented 4 months ago

@JeyRunner - you can refresh the pool of reset states at some regular interval without a performance hit, see the param num_resets_per_eval in brax's PPO trainer for an example of how to do this.

@7oponaut I'm glad you figured this out! Since it seems the issue is resolved, I'm going to go ahead and close. Feel free to open a new issue if you hit any more blockers.
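
For readers who want to try the cached-first-state idea mentioned in this thread, here is a minimal sketch. It assumes toy dict pytrees stand in for batched mjx.Data, and `where_done` is a hypothetical helper, not a brax or MJX function:

```python
import jax
import jax.numpy as jnp


def where_done(done, reset_state, state):
    """Per environment, take leaves from reset_state where done is True."""

    def pick(r, s):
        # Reshape the (batch,) mask so it broadcasts against each leaf.
        mask = done.reshape(done.shape + (1,) * (s.ndim - 1))
        return jnp.where(mask, r, s)

    return jax.tree.map(pick, reset_state, state)


# Cached state produced once by a reset, and the current state after a step.
first_state = {"qpos": jnp.zeros((4, 3)), "time": jnp.zeros((4,))}
state = {"qpos": jnp.ones((4, 3)), "time": jnp.full((4,), 0.5)}
done = jnp.array([True, False, False, True])

state = jax.jit(where_done)(done, first_state, state)
print(state["time"])
```

This replaces a per-step re-initialization with a cheap elementwise select; the cached states would still need refreshing at some interval to keep the initial conditions varied.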