Closed 7oponaut closed 4 months ago
I did more research, reading the mujoco docs and the brax source code.
The relevant bits for basic interaction with the mjx objects seem to be at brax/mjx/pipeline.py. The class System is subclassed from mjx.Model and the class State is subclassed from mjx.Data. Looks like setting mjx_data.ctrl is the way to set the actuators like I mentioned.
I came up with some code that I think addresses my questions, including the one about how to simulate environment instances with different height maps. I am using jax for the first time so please excuse my French. In short, I batch both mjx_model and mjx_data and I use jax.vmap() to vectorize everything.
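As an aside, replicating every leaf of the model is not the only way to get per-environment fields. A minimal sketch, assuming a toy dict stands in for mjx.Model (toy_model, toy_step and the field names are hypothetical, and this has not been checked against a real mjx.Model): jax.vmap accepts a pytree for in_axes, so only the randomized field needs a batch axis while everything else stays shared.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

batch_size = 4

# Toy stand-in for a model: one per-environment field, one shared field.
toy_model = {
    "hfield_data": jr.uniform(jr.PRNGKey(0), (batch_size, 6)),  # batched
    "gravity": jnp.array([0.0, 0.0, -9.81]),                    # shared
}
toy_state = jnp.zeros((batch_size, 3))


def toy_step(model, state):
    # Placeholder dynamics (not physics): just uses both fields.
    return state + 0.01 * model["gravity"] + model["hfield_data"].mean()


# in_axes mirrors the argument structure: 0 = batched, None = broadcast.
model_axes = {"hfield_data": 0, "gravity": None}
batched_step = jax.jit(jax.vmap(toy_step, in_axes=(model_axes, 0)))

new_state = batched_step(toy_model, toy_state)
print(new_state.shape)
```

Whether this saves anything over a fully replicated model probably depends on how large the shared fields are; the fully batched version is simpler to reason about.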
The code I am posting here is set up for testing on mujoco_menagerie/google_barkour_vb/scene_hfield_mjx.xml to make it reproducible, although I am working with a slightly different setup. You can save the code in a new file in the mujoco_menagerie repo folder and run it from there.
jax.jit() seems to be crucial for acceptable runtimes: mjx.step() is very slow without it, even on subsequent runs.
I am measuring ~1 ms runtimes on the second run for reset() and step() on my machine (RTX 4090).
import time

import jax
import jax.numpy as jnp
import jax.random as jr
import mujoco
import mujoco.mjx as mjx


def create_hfield(rng, shape) -> jnp.ndarray:
    # dummy implementation
    return jr.uniform(rng, shape, minval=0.0, maxval=1.0)


def create(xml_path: str, batch_size: int):
    mj_model = mujoco.MjModel.from_xml_path(xml_path)
    mjx_model = mjx.put_model(mj_model)
    # replicate every leaf of the model along a new leading batch axis
    mjx_model_batch = jax.tree.map(
        lambda x: x[None].repeat(batch_size, axis=0), mjx_model
    )
    mjx_data_batch = jax.vmap(mjx.make_data)(mjx_model_batch)
    return mjx_model, mjx_model_batch, mjx_data_batch


def reset_(key, mjx_model_batch, mjx_data_batch, reset_mask: jnp.ndarray):
    # one PRNG key per environment instance
    key = jr.split(key, reset_mask.shape)
    # draw a new height field only for the instances selected by the mask
    mjx_model_batch = jax.vmap(
        lambda key, reset, mjx_model: (
            jax.lax.cond(
                reset,
                lambda _: mjx_model.replace(
                    hfield_data=create_hfield(key, mjx_model.hfield_data.shape)
                ),
                lambda _: mjx_model,
                operand=None,
            )
        )
    )(key, reset_mask, mjx_model_batch)
    del key
    # recreate the data for the selected instances, keep the rest unchanged
    mjx_data_batch = jax.vmap(
        lambda reset, mjx_model, mjx_data: (
            jax.lax.cond(
                reset,
                lambda _: mjx.make_data(mjx_model),
                lambda _: mjx_data,
                operand=None,
            )
        )
    )(reset_mask, mjx_model_batch, mjx_data_batch)
    return mjx_model_batch, mjx_data_batch


def step_(mjx_model_batch, mjx_data_batch, action):
    mjx_data_batch = mjx_data_batch.replace(ctrl=action)
    mjx_data_batch = jax.vmap(mjx.step)(mjx_model_batch, mjx_data_batch)
    return mjx_data_batch


def main() -> None:
    key = jr.key(42)
    batch_size = 4096
    reset = jax.jit(reset_)
    step = jax.jit(step_)
    for idx in range(2):
        print(f"run {idx + 1}")
        t = time.perf_counter()
        mjx_model, mjx_model_batch, mjx_data_batch = create(
            "google_barkour_vb/scene_hfield_mjx.xml",
            batch_size,
        )
        print("create", time.perf_counter() - t)
        key, subkey = jr.split(key)
        t = time.perf_counter()
        # pass the fresh subkey (not key) so that key is never reused
        mjx_model_batch, mjx_data_batch = reset(
            subkey,
            mjx_model_batch,
            mjx_data_batch,
            jnp.ones((batch_size,), dtype=jnp.bool),
        )
        print("reset all", time.perf_counter() - t)
        del subkey
        ctrl_range = mjx_model.actuator_ctrlrange
        key, subkey = jr.split(key)
        action = jr.uniform(
            subkey,
            (batch_size, ctrl_range.shape[0]),
            minval=ctrl_range[:, 0],
            maxval=ctrl_range[:, 1],
        )
        del subkey
        t = time.perf_counter()
        mjx_data_batch = step(mjx_model_batch, mjx_data_batch, action)
        print("step", time.perf_counter() - t)
        # reset only the first half of the batch
        reset_mask = jnp.zeros((batch_size,), dtype=jnp.bool)
        reset_mask = reset_mask.at[: batch_size // 2].set(True)
        key, subkey = jr.split(key)
        t = time.perf_counter()
        mjx_model_batch, mjx_data_batch = reset(
            subkey, mjx_model_batch, mjx_data_batch, reset_mask
        )
        print("reset some", time.perf_counter() - t)
        del subkey


if __name__ == "__main__":
    main()
Output:
run 1
create 2.4136174280001796
reset all 0.368587808999564
step 5.961428407999847
reset some 0.0013962110001557448
run 2
create 0.10277210800040848
reset all 0.0010273100001541025
step 0.0009445460000279127
reset some 0.0007678189999751339
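A caveat on timings like these: JAX dispatches work asynchronously, so time.perf_counter() around a jitted call can stop before the device has finished, and the ~1 ms figures may partly reflect dispatch time rather than completed work. A minimal sketch of a timing helper that waits on the result with jax.block_until_ready() before reading the clock, with a toy function standing in for step() (timed and toy_step are hypothetical names):

```python
import time

import jax
import jax.numpy as jnp


def timed(fn, *args):
    """Run fn(*args), wait for the device to finish, return (result, seconds)."""
    t = time.perf_counter()
    out = jax.block_until_ready(fn(*args))
    return out, time.perf_counter() - t


@jax.jit
def toy_step(x):
    # Stand-in for a batched simulator step.
    return jnp.tanh(x @ x.T).sum(axis=1)


x = jnp.ones((512, 512))
_, first = timed(toy_step, x)     # includes tracing and compilation
out, second = timed(toy_step, x)  # runs the cached executable only
print(f"first call {first:.4f}s, second call {second:.4f}s")
```

The same helper should work on the pytrees returned by reset() and step(), since jax.block_until_ready() accepts arbitrary pytrees of arrays.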
I iterated on this and came up with a solution that's more viable I think. The code has gotten a bit long to keep posting it in github comments but I will put it here for completeness.
In this version, I define init_mjx_single() and step_mjx_single() that run on single mjx.Model and mjx.Data instances. This seems more in line with jax examples I have seen. I can use these functions to take care of environment-specific business:
from jax import Array  # the remaining imports are the same as in the first listing


def _get_obs(mjx_data: mjx.Data) -> Array:
    return jnp.concatenate(
        [
            mjx_data.qpos,
            mjx_data.qvel,
            mjx_data.cinert.ravel(),
            mjx_data.cvel.ravel(),
        ]
    )


def init_mjx_single(
    key: Array, mjx_model: mjx.Model
) -> tuple[mjx.Model, mjx.Data, Array]:
    # dummy random height field, unique per environment instance
    hfield_data = jr.uniform(key, mjx_model.hfield_data.shape, minval=0.0, maxval=1.0)
    mjx_model = mjx_model.replace(hfield_data=hfield_data)
    mjx_data = mjx.make_data(mjx_model)
    mjx_data = mjx.forward(mjx_model, mjx_data)
    obs = _get_obs(mjx_data)
    return mjx_model, mjx_data, obs


def step_mjx_single(
    mjx_model: mjx.Model, mjx_data: mjx.Data, action: Array
) -> tuple[mjx.Data, Array, float, bool, bool]:
    # map actions from [-1, 1] onto each actuator's control range
    action = action.clip(min=-1.0, max=1.0)
    ctrl_range = mjx_model.actuator_ctrlrange
    ctrl_mid = ctrl_range.mean(axis=1)
    ctrl_half = ctrl_range[:, 1] - ctrl_mid
    action = ctrl_mid + ctrl_half * action
    mjx_data = mjx_data.replace(ctrl=action)
    mjx_data = mjx.step(mjx_model, mjx_data)
    obs = _get_obs(mjx_data)
    # placeholders: reward and termination logic are task-specific
    reward = 0.0
    terminated = False
    truncated = False
    return mjx_data, obs, reward, terminated, truncated
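To make the rescaling in step_mjx_single() concrete: it is the affine map ctrl = mid + half * action, which sends -1 to the lower bound and +1 to the upper bound of each actuator. A small sketch with made-up control ranges (scale_action is a hypothetical helper name, and the numbers are illustrative only):

```python
import jax.numpy as jnp


def scale_action(action, ctrl_range):
    """Map actions in [-1, 1] onto per-actuator [low, high] ranges."""
    action = jnp.clip(action, -1.0, 1.0)
    ctrl_mid = ctrl_range.mean(axis=1)
    ctrl_half = ctrl_range[:, 1] - ctrl_mid
    return ctrl_mid + ctrl_half * action


# Two made-up actuators: one with range [-2, 4], one with range [0, 1].
ctrl_range = jnp.array([[-2.0, 4.0], [0.0, 1.0]])

low = scale_action(jnp.array([-1.0, -1.0]), ctrl_range)
mid = scale_action(jnp.array([0.0, 0.0]), ctrl_range)
high = scale_action(jnp.array([5.0, 1.0]), ctrl_range)  # 5.0 is clipped to 1.0
print(low, mid, high)
```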
There are three more functions.
One that creates the batched mjx.Model data structure:
def create(xml_path: str, batch_size: int) -> mjx.Model:
    mj_model = mujoco.MjModel.from_xml_path(xml_path)
    mjx_model = mjx.put_model(mj_model)
    mjx_model_batch = jax.tree.map(
        lambda x: x[None].repeat(batch_size, axis=0), mjx_model
    )
    return mjx_model_batch
A batched reset() function constructed from init_mjx_single(). No need for masked resets here after all: I take care of that in the step() function.
def make_reset(init_mjx):
    def reset(key: Array, mjx_model: mjx.Model):
        batch_size = mjx_model.hfield_data.shape[0]
        key = jr.split(key, batch_size)
        mjx_model, mjx_data, obs = jax.vmap(init_mjx)(key, mjx_model)
        return mjx_model, mjx_data, obs

    return reset
A batched step() function constructed from init_mjx_single() and step_mjx_single(). This takes care of resets as well:
def make_step(init_mjx, step_mjx):
    def init_mjx_(key: Array, mjx_model: mjx.Model):
        mjx_model, mjx_data, obs = init_mjx(key, mjx_model)
        return mjx_model, mjx_data, obs, 0.0, False, False

    def step_mjx_(mjx_model: mjx.Model, mjx_data: mjx.Data, action: Array):
        mjx_data, obs, reward, terminated, truncated = step_mjx(
            mjx_model, mjx_data, action
        )
        return mjx_model, mjx_data, obs, reward, terminated, truncated

    def step(
        key: Array,
        mjx_model: mjx.Model,
        mjx_data: mjx.Data,
        terminated: Array,
        truncated: Array,
        action: Array,
    ):
        done = terminated | truncated
        key = jr.split(key, done.shape[0])
        # per instance: re-initialize if the episode is done, otherwise step
        mjx_model, mjx_data, obs, reward, terminated, truncated = jax.vmap(
            lambda key, mjx_model, mjx_data, done, action: (
                jax.lax.cond(
                    done,
                    lambda _: init_mjx_(key, mjx_model),
                    lambda _: step_mjx_(mjx_model, mjx_data, action),
                    operand=None,
                )
            )
        )(key, mjx_model, mjx_data, done, action)
        return mjx_model, mjx_data, obs, reward, terminated, truncated

    return step
Code for testing run times:
def main() -> None:
    key = jr.key(42)
    batch_size = 4096
    reset = jax.jit(make_reset(init_mjx_single))
    step = jax.jit(make_step(init_mjx_single, step_mjx_single))
    for idx in range(2):
        print(f"run {idx + 1}")
        t = time.perf_counter()
        mjx_model_batch = create(
            "mujoco_menagerie/google_barkour_vb/scene_hfield_mjx.xml",
            batch_size,
        )
        print("create", time.perf_counter() - t)
        ctrl_range = mjx_model_batch.actuator_ctrlrange[0]
        key, subkey = jr.split(key)
        t = time.perf_counter()
        # pass the fresh subkey (not key) so that key is never reused
        mjx_model_batch, mjx_data_batch, obs_batch = reset(subkey, mjx_model_batch)
        print("reset", time.perf_counter() - t)
        del subkey
        key, subkey = jr.split(key)
        action = jr.uniform(
            subkey,
            (batch_size, ctrl_range.shape[0]),
            minval=-1.0,
            maxval=1.0,
        )
        del subkey
        terminated = jnp.zeros((batch_size,), dtype=jnp.bool)
        truncated = jnp.zeros((batch_size,), dtype=jnp.bool)
        key, subkey = jr.split(key)
        t = time.perf_counter()
        mjx_model_batch, mjx_data_batch, obs, reward, terminated, truncated = step(
            subkey, mjx_model_batch, mjx_data_batch, terminated, truncated, action
        )
        print("step", time.perf_counter() - t)
        del subkey
        # mark the first half of the batch as done to exercise the reset branch
        terminated = terminated.at[: batch_size // 2].set(True)
        key, subkey = jr.split(key)
        action = jr.uniform(
            subkey,
            (batch_size, ctrl_range.shape[0]),
            minval=-1.0,
            maxval=1.0,
        )
        del subkey
        key, subkey = jr.split(key)
        t = time.perf_counter()
        mjx_model_batch, mjx_data_batch, obs, reward, terminated, truncated = step(
            subkey, mjx_model_batch, mjx_data_batch, terminated, truncated, action
        )
        print("step/reset", time.perf_counter() - t)
        del subkey


if __name__ == "__main__":
    main()
Output:
run 1
create 1.4571946640062379
reset 6.0006099570018705
step 11.913964218998444
step/reset 0.0019596930069383234
run 2
create 0.09385564400872681
reset 0.002368453991948627
step 0.0017020349914673716
step/reset 0.0014810299908276647
Looks interesting :)
Did you compare the performance of the step function with and without the jax.lax.cond() around init_mjx_?
I am not sure whether cond ensures that only the selected branch is executed (even though only the selected branch's value is used as the result of the expression). If both branches run, this would add the overhead of executing init_mjx_ at every step.
The XLA documentation does describe the conditional as executing only the selected branch (https://openxla.org/xla/operation_semantics#conditional). Nevertheless, it would be interesting to test this, since in brax the environment is not really reset in the step function; instead, a cached first state is applied (for performance reasons, I guess).
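One way to check this is to look at the jaxpr. The JAX documentation for jax.lax.cond notes that when it is transformed with vmap over a batch of predicates, it is converted to a select, which means both branches are evaluated for every element. A minimal sketch with toy branches standing in for init_mjx_ and step_mjx_ (reset_or_step is a hypothetical name):

```python
import jax
import jax.numpy as jnp


def reset_or_step(done, x):
    # Toy stand-ins: "reset" returns zeros, "step" increments the state.
    return jax.lax.cond(done, lambda v: jnp.zeros_like(v), lambda v: v + 1.0, x)


done = jnp.array([True, False, False])
x = jnp.array([5.0, 6.0, 7.0])

# Scalar predicate: a real conditional appears in the traced program.
single = str(jax.make_jaxpr(reset_or_step)(done[0], x[0]))
# Batched predicate: the conditional is replaced by an elementwise select.
batched = str(jax.make_jaxpr(jax.vmap(reset_or_step))(done, x))

print("cond" in single, "select" in batched)
out = jax.vmap(reset_or_step)(done, x)
print(out)
```

If the same holds for the vmapped step() above, every call would pay for both an init and a step, which may be part of why brax applies a cached first state instead.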
@JeyRunner - you can refresh the pool of reset states at some regular interval without a performance hit, see the param num_resets_per_eval in brax's PPO trainer for an example of how to do this.
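The cached-first-state idea can be sketched without brax as a per-leaf jnp.where over the state pytree. This is a toy illustration under stated assumptions: the state is a plain dict rather than mjx.Data, and where_done is a hypothetical helper, not a brax or mjx function.

```python
import jax
import jax.numpy as jnp


def where_done(done, first_state, next_state):
    """Per environment, pick the cached first state where done, else next_state."""

    def pick(first, nxt):
        # Reshape done to (batch, 1, 1, ...) so it broadcasts over each leaf.
        mask = done.reshape(done.shape + (1,) * (nxt.ndim - 1))
        return jnp.where(mask, first, nxt)

    return jax.tree.map(pick, first_state, next_state)


batch_size = 4
first_state = {"qpos": jnp.zeros((batch_size, 3)), "time": jnp.zeros((batch_size,))}
next_state = {"qpos": jnp.ones((batch_size, 3)), "time": jnp.full((batch_size,), 0.5)}
done = jnp.array([True, False, True, False])

state = jax.jit(where_done)(done, first_state, next_state)
print(state["qpos"][:, 0], state["time"])
```

The first states would be computed once (or refreshed at some interval), so the per-step cost is a select over the state rather than a full re-initialization.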
@7oponaut I'm glad you figured this out! Since it seems the issue is resolved, I'm going to go ahead and close. Feel free to open a new issue if you hit any more blockers.
Hello,
I am currently working on an RL project that involves training models from mujoco_menagerie to perform certain tasks.
I already have my own working code for the whole process, which uses pytorch. I would like to accelerate the mujoco simulator using mjx, so I am planning on porting my code to jax.
I want to implement a gymnasium-like environment interface such that the batch of RL environments can be accelerated using my GPU.
I am aware of the tutorial notebook that covers mjx acceleration. This notebook uses brax for interacting with the simulator and brax's own ppo implementation. As I have said, I have my own implementation which I want to stick with. I have been browsing the mujoco and brax githubs to figure out how to implement the primitives I need for an accelerated training loop.
I need a way to:
1) reset a batch of mujoco simulations with the option to control which instances to reset (e.g. with masking)
2) randomly modify the initial state configuration per simulator instance to diversify training data
3) convey batched agent actions to the simulator instances
I would like to perform these actions on the accelerated mjx objects.
The mjx tutorial uses brax, which abstracts away implementation details. I would like your assistance with figuring out how to implement the primitives I need without brax, since I would like to understand how to work with mujoco/mjx/jax directly, and I want flexibility in how my code is structured.
1) For simulator resets, I suspect that it might be enough to reset mjx_data.qpos and mjx_data.qvel to their initial values, but mujoco.mj_resetData() does a bunch of other things based on the code in src/engine/engine_io.c
How should I reset mjx_data?
2) Afaik modifying the freejoint parameters in mjx_data.qpos after reset should be sufficient to randomize the robot configuration.
Another issue: I use hfields to randomize terrain for each environment instance. In my current implementation I have a unique mj_model for each environment instance, and I can modify the mj_model.hfield_data of each environment to a unique value (of identical shape) during reset.
However, in the mjx tutorial, mjx.step() takes as input a single mjx_model and a batch of mjx_data. Since mjx_model contains the hfield_data array, there seems to be no way to set a different hfield for each environment instance. Is there a way around this?
3) Is setting the mjx_data.ctrl array the canonical way of sending control signals when using mjx?
It would be nice to have a brax-free mjx tutorial that clarifies these questions. If this already exists, please do refer me to it.