google-deepmind / mujoco

Multi-Joint dynamics with Contact. A general purpose physics simulator.
https://mujoco.org
Apache License 2.0

Domain randomization for MJX #1684

Closed: alexhansson closed this issue 5 days ago

alexhansson commented 3 months ago

Hi,

I'm a student trying to use MuJoCo to train RL policies for space systems. According to the Google Colab notebook linked in the MuJoCo documentation, one can simulate and randomize multiple environments in parallel with this code:

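```python
import jax
from mujoco import mjx

# mjx_model / mjx_data come from mjx.put_model / mjx.put_data (single hinge joint, nq == 1).
rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, 4096)
# One mjx.Data per environment, each with a random initial joint position.
batch = jax.vmap(lambda rng: mjx_data.replace(qpos=jax.random.uniform(rng, (1,))))(rng)

# The model is shared (in_axes=None); only the Data is batched.
jit_step = jax.jit(jax.vmap(mjx.step, in_axes=(None, 0)))
batch = jit_step(mjx_model, batch)
```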

However, this only varies the initial position of the object. I'm wondering if the environments can be varied more flexibly, for example by randomizing the location/dimensions of collision objects in the scene. I know certain parameters can't be changed once the XML has been loaded, but would it in theory be possible to have multiple XML files and "concatenate" them in some way for parallel simulation?

And as a side question: is it possible to visualize the parallel environments, similar to Isaac Gym?

Thanks a lot for your help.

btaba commented 3 months ago

Hi @alexhansson, the same notebook has a demo that varies other model parameters to train an RL policy with domain randomization. Search for "Training a Policy with Domain Randomization".

RE rendering: there isn't an out-of-the-box way to render parallel environments. Here are some relevant discussions (you have already commented on some of them):

  - https://github.com/google-deepmind/mujoco/issues/1682
  - https://github.com/google-deepmind/mujoco/issues/1604
  - https://github.com/google-deepmind/mujoco/issues/1356

alexhansson commented 3 months ago

Thanks for answering @btaba

I have seen the "Training a Policy with Domain Randomization" section. However, it only randomizes some parameters, like friction and actuator gain/bias.
In our application we want to simulate multiple environments in parallel, varying the number, shape and size of the collision objects our agent has to maneuver around. I see two ways this could be possible (though I don't know if they work like this):

  1. Create an XML file parser and "combine" multiple randomized environments (i.e. convert each XML into a Model and Data object) into some sort of array for parallel simulation.
  2. Put all collision objects, spread across the environment, into a single scene and vary properties like their positions (but not their shape or size).

Could you provide more insight on this? I'm also happy to explain this in more detail if needed.

RE rendering: thanks for linking the relevant issues. I take it from them that you're working on it, which is very cool.

Thanks for your time!

btaba commented 3 months ago

Hi @alexhansson, varying the shape and size of collision objects can be done similarly to the demo in the Colab; see https://mujoco.readthedocs.io/en/stable/XMLreference.html#body-geom-size as an example. With meshes this is a bit harder, but it will be viable (I'm currently working on a commit for it); see the relevant issue: https://github.com/google-deepmind/mujoco/issues/1655
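geom_size always has three entries, but each primitive type reads only some of them (sphere: radius; capsule and cylinder: radius and half-length; ellipsoid: three radii; box: three half-sizes). So a size randomizer probably should only touch the entries that matter. Here is a minimal pure-JAX sketch. It assumes the integer type codes match MuJoCo's mjtGeom enum (sphere = 2 through box = 6) and that the arrays would come from the model's geom_type and geom_size fields. The function names and example values are hypothetical.

```python
import jax
import jax.numpy as jnp

# Integer codes of MuJoCo's mjtGeom enum for the primitive types handled here.
GEOM_SPHERE, GEOM_CAPSULE, GEOM_ELLIPSOID, GEOM_CYLINDER, GEOM_BOX = 2, 3, 4, 5, 6


def used_size_entries(geom_type):
    """(ngeom, 3) boolean mask of the geom_size entries each geom type actually reads."""
    n_used = jnp.where(
        geom_type == GEOM_SPHERE, 1,
        jnp.where((geom_type == GEOM_CAPSULE) | (geom_type == GEOM_CYLINDER), 2,
                  jnp.where((geom_type == GEOM_ELLIPSOID) | (geom_type == GEOM_BOX), 3, 0)))
    # Planes, meshes and other types get 0 used entries and are left unchanged.
    return jnp.arange(3)[None, :] < n_used[:, None]


def scale_primitive_sizes(key, geom_type, geom_size, lo=0.8, hi=1.2):
    """Multiply each used size entry by an independent factor drawn from [lo, hi)."""
    scale = jax.random.uniform(key, geom_size.shape, minval=lo, maxval=hi)
    return jnp.where(used_size_entries(geom_type), geom_size * scale, geom_size)


# Illustrative per-geom data; in practice these would be the model's geom_type / geom_size.
geom_type = jnp.array([GEOM_SPHERE, GEOM_CAPSULE, GEOM_BOX, 0])
geom_size = jnp.array([[0.1, 0.0, 0.0], [0.05, 0.2, 0.0], [0.3, 0.3, 0.3], [5.0, 5.0, 0.1]])
keys = jax.random.split(jax.random.PRNGKey(0), 4096)
sizes = jax.vmap(scale_primitive_sizes, in_axes=(0, None, None))(keys, geom_type, geom_size)
```

Vmapping over a split key gives one size table per environment. Fields the compiler derived from the original sizes (such as geom_rbound and, where inertia was inferred from geoms, body inertias) are not recomputed by this sketch.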

To vary the number of objects, [1] and [2] both sound like good ideas. For [1], you won't be able to stack Models whose arrays have different shapes, so you'll have to dispatch to each scene on the host. For [2], you'd load all the objects into the same environment and, at reset, "mask" the objects you don't want to interact with, potentially by moving them out of the main scene. Hope this helps for now.
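For [1], dispatching to each scene on the host can be pictured as a plain Python loop over scene variants, each going through a jitted, vmapped step. This is a minimal sketch that runs with JAX alone. toy_step and the dict-based "models" are hypothetical stand-ins for mjx.step and mjx.Model, and the obstacle count is the only thing that differs between the illustrative variants.

```python
import jax
import jax.numpy as jnp


def toy_step(model, data):
    """Hypothetical stand-in for mjx.step: advance obstacle positions by velocity * dt."""
    return {"pos": data["pos"] + model["dt"] * data["vel"], "vel": data["vel"]}


# The model is shared within a variant (None); the per-environment data is batched (0).
# jax.jit compiles a separate executable for each distinct set of input shapes.
batched_step = jax.jit(jax.vmap(toy_step, in_axes=(None, 0)))


def make_variant(num_envs, num_obstacles, dt=0.01):
    """Hypothetical scene variant: a model dict plus a batch of num_envs data entries."""
    model = {"dt": jnp.asarray(dt)}
    data = {"pos": jnp.zeros((num_envs, num_obstacles, 3)),
            "vel": jnp.ones((num_envs, num_obstacles, 3))}
    return model, data


# Variants with 2 and 5 obstacles have differently shaped arrays, so they cannot be
# stacked along a common batch axis. Instead the host loops over them.
variants = {2: make_variant(8, 2), 5: make_variant(8, 5)}
for _ in range(10):  # ten simulation steps
    variants = {n: (model, batched_step(model, data)) for n, (model, data) in variants.items()}
```

Each distinct structure pays its own compilation, so the cost grows with the number of variants rather than the number of environments.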

alexhansson commented 3 months ago

Hey @btaba, sounds good. So I could do the following: create one environment with 1 agent and M collision objects/shapes. Before running the episodes, I would randomize the position, shape and size of the obstacles (so following [2]) for N parallel environments. After that, I can execute the N environments in parallel with the M obstacles placed at randomized locations?
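This plan can be sketched for approach [2] under a few assumptions. All M obstacles exist in the XML. Each environment uses a random number k of them, and the rest are parked at distant, non-overlapping spots. The parking spots, arena bounds and names are hypothetical, and the sampled positions would still have to be written into each environment's mjx.Data.

```python
import jax
import jax.numpy as jnp

M = 6  # obstacles defined once in the XML (illustrative value)
# Hypothetical parking spots far from the arena, spaced apart so parked obstacles don't overlap.
PARK = jnp.stack([10.0 * jnp.arange(M), jnp.zeros(M), jnp.full(M, -100.0)], axis=1)


def sample_obstacles(key, min_active=1, arena=5.0):
    """Place a random number k of the M obstacles in the arena; park the remaining M - k."""
    k_key, p_key = jax.random.split(key)
    k = jax.random.randint(k_key, (), min_active, M + 1)       # k in [min_active, M]
    pos = jax.random.uniform(p_key, (M, 3), minval=-arena, maxval=arena)
    active = jnp.arange(M) < k                                  # first k obstacles are active
    return jnp.where(active[:, None], pos, PARK), active


keys = jax.random.split(jax.random.PRNGKey(0), 4096)           # one key per environment
positions, active = jax.vmap(sample_obstacles)(keys)           # (4096, M, 3), (4096, M)
```

Every environment keeps exactly M obstacles, so array shapes never change and the batch can be vmapped. Only the values differ.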

btaba commented 3 months ago

Hey @alexhansson,

Sounds good for resetting the object positions at environment reset time. The positions are part of the mjx.Data, which is part of the environment State.
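If each obstacle is a body with a free joint, its position lives in qpos. A free joint occupies 7 consecutive entries (xyz position, then a w-x-y-z quaternion) starting at the joint's jnt_qposadr. Below is a minimal sketch of writing new obstacle positions into a qpos vector. The helper name and the example layout are hypothetical. In an MJX reset, the result would be stored with data.replace(qpos=...), as in the snippet from the original question. If the obstacles were mocap bodies instead, the corresponding mjx.Data field would be mocap_pos.

```python
import jax.numpy as jnp


def set_obstacle_positions(qpos, qpos_adr, positions):
    """Overwrite the xyz part of each obstacle's free joint in qpos.

    qpos_adr: (M,) start index of each obstacle's free joint (e.g. from mj_model.jnt_qposadr).
    positions: (M, 3) new xyz positions. Quaternion entries are left untouched.
    """
    idx = qpos_adr[:, None] + jnp.arange(3)[None, :]   # (M, 3) indices of the xyz entries
    return qpos.at[idx].set(positions)


# Illustrative layout: agent free joint at 0..6, two obstacle free joints at 7..13 and 14..20.
qpos = jnp.zeros(21).at[jnp.array([3, 10, 17])].set(1.0)   # unit quaternions (w = 1)
qpos_adr = jnp.array([7, 14])
new_qpos = set_obstacle_positions(qpos, qpos_adr, jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]))
```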

For randomizing the size/shape (especially easy for primitives), I would use the domain randomization utility on geom_size, which modifies the mjx.Model so that each parallelized environment gets a different size. You can follow the example in the Colab. Let us know if you have any trouble.
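The idea behind that utility is to give the randomized model fields a leading environment axis. vmap then receives an in_axes pytree that marks only those fields as batched. The Colab's domain_randomize example builds this in_axes for a real model by mapping every field to None and then setting the randomized fields to 0. Below is a minimal sketch that runs with JAX alone. The model is a plain dict standing in for mjx.Model, toy_step is a hypothetical stand-in for mjx.step, and the geom indices and scaling range are illustrative.

```python
import jax
import jax.numpy as jnp


def batch_geom_size(model, key, num_envs, geom_ids, lo=0.5, hi=1.5):
    """Give geom_size a leading env axis (uniform per-geom rescaling) and return in_axes."""
    keys = jax.random.split(key, num_envs)

    def one_env(k):
        scale = jax.random.uniform(k, (geom_ids.shape[0], 1), minval=lo, maxval=hi)
        return model["geom_size"].at[geom_ids].multiply(scale)  # keeps each geom's proportions

    sizes = jax.vmap(one_env)(keys)                 # (num_envs, ngeom, 3)
    in_axes = {name: None for name in model}        # every field shared across envs ...
    in_axes["geom_size"] = 0                        # ... except the randomized one
    return {**model, "geom_size": sizes}, in_axes


def toy_step(model, x):
    """Hypothetical stand-in for mjx.step that reads the (per-env) geom sizes."""
    return x + model["timestep"] * model["geom_size"].sum()


model = {"geom_size": jnp.ones((3, 3)), "timestep": jnp.asarray(0.002)}
batched_model, in_axes = batch_geom_size(model, jax.random.PRNGKey(0), 8, jnp.array([1, 2]))
step = jax.jit(jax.vmap(toy_step, in_axes=(in_axes, 0)))
x = step(batched_model, jnp.zeros(8))               # one result per environment
```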

alexhansson commented 3 months ago

Sounds good. Will reply in this thread if there are any issues down the line. Thanks!

Edit: one more question. What exactly do you mean by reset time? Do you mean the custom method that runs when I call env.reset(), or are you referring to a built-in one? @btaba

btaba commented 3 months ago

Hey @alexhansson, by reset time I mean in the environment's reset function. The relevant example in the tutorial is how a command is randomly sampled with sample_command during the environment reset. Also notice how pipeline_state = self.pipeline_init(self._init_q, jp.zeros(self._nv)) sets the joint positions during the reset; you could equally add some noise to init_q to get random starting joint positions.
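The reset-time sampling described above can be sketched as a standalone function so it runs with JAX alone. The helper name, the noise scale, and the 3-component command range are hypothetical. Inside a Brax-style environment, q and qd would be passed to self.pipeline_init(q, qd) in place of self._init_q and the zero velocities.

```python
import jax
import jax.numpy as jnp


def sample_reset(rng, init_q, nv, noise_scale=0.05, cmd_low=-1.0, cmd_high=1.0):
    """Hypothetical reset helper: noisy joint positions, zero velocities, random command."""
    q_key, cmd_key = jax.random.split(rng)
    # Uniform noise in [-noise_scale, noise_scale) added to every entry of init_q.
    q = init_q + noise_scale * jax.random.uniform(q_key, init_q.shape, minval=-1.0, maxval=1.0)
    qd = jnp.zeros(nv)
    command = jax.random.uniform(cmd_key, (3,), minval=cmd_low, maxval=cmd_high)
    return q, qd, command


init_q = jnp.array([0.0, 0.5, -0.5])                   # illustrative hinge-joint positions
keys = jax.random.split(jax.random.PRNGKey(0), 16)      # one key per environment
q, qd, command = jax.vmap(lambda k: sample_reset(k, init_q, nv=3))(keys)
```

Adding noise to every entry only suits hinge and slide joints. For a free joint, the four quaternion entries would need to be left alone or renormalized, otherwise the orientation is no longer a unit quaternion.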