google-deepmind / mujoco

Multi-Joint dynamics with Contact. A general purpose physics simulator.
https://mujoco.org
Apache License 2.0

Call mj_setConst on an mjx model after changing body masses and inertias for domain randomization #1607

Open JeyRunner opened 2 months ago

JeyRunner commented 2 months ago

Hi, I am a student using MJX together with Brax for RL. To implement domain randomization, I followed the example notebook, which looks like this:

```python
import jax

def domain_randomize(sys, rng):
  """Randomizes the mjx.Model."""
  @jax.vmap
  def rand(rng):
    # randomize frictions ...
    return friction, gain, bias
  friction, gain, bias = rand(rng)

  # Only geom_friction is batched along axis 0; all other fields stay shared.
  in_axes = jax.tree_util.tree_map(lambda x: None, sys)
  in_axes = in_axes.tree_replace({
      'geom_friction': 0,
  })

  # Swap in the per-environment friction values.
  sys = sys.tree_replace({
      'geom_friction': friction,
  })

  return sys, in_axes
```

Now I also want to randomize body masses and inertias, but according to https://github.com/google-deepmind/mujoco/issues/764#issuecomment-1464708759, this also requires updating dependent constants such as dof_M0 ...
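For context on why dof_M0 is affected: the joint-space inertia matrix is a sum of per-body terms that are linear in each body's mass and rotational inertia, and the MuJoCo docs describe dof_M0 as the "diag. inertia in qpos0". The sketch below uses standard rigid-body notation introduced here (not from the issue) and ignores joint armature: $J_{p,b}$ is the translational Jacobian of body $b$'s center of mass, $J_{r,b}$ its rotational Jacobian, $R_b$ the world orientation of its inertial frame, $m_b$ = body_mass[b] and $I_b$ = diag(body_inertia[b]).

$$
M(q) = \sum_{b} \Big( m_b\, J_{p,b}(q)^\top J_{p,b}(q) + J_{r,b}(q)^\top R_b(q)\, I_b\, R_b(q)^\top J_{r,b}(q) \Big),
\qquad
\mathrm{dof\_M0}_i = M_{ii}(\mathrm{qpos0})
$$

Replacing only $m_b$ and $I_b$ in the model therefore leaves dof_M0 describing the old mass distribution until it is recomputed.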

With the regular MuJoCo (non-MJX) pipeline, the solution seems to be to call mj_setConst after changing the link masses/inertias, so that the dependent constants are recalculated.

With MJX, calling mj_setConst is not possible because, as far as I understand, there is no way to copy an mjx model back to a mujoco model (i.e., there is no mjx.get_model function). If such a function existed, the workflow could look like this:

Or is there another way to call mj_setConst directly on a batched mjx model?

lkoelman commented 1 month ago

Bumping into the same problem. Having a solution to this issue would definitely make the domain randomisation functionality much more useful.

JeyRunner commented 1 month ago
  1. I guess the best solution would be to port everything in mj_setConst (https://github.com/google-deepmind/mujoco/blob/1d181786a25da07a9ac536ae0e78291b221e27f0/src/engine/engine_setconst.c#L506C6-L506C17) to MJX/JAX, but I am not sure how much work that would be.

  2. In the meantime, the model-copying approach could be used as a workaround. However, this may carry a significant performance cost, especially in an RL training loop where domain randomization is performed on every reset across many environments.

  3. Another workaround may be to call mj_setConst(mjModel* m, mjData* d) via jax.experimental.io_callback from within the jitted domain_randomize function, but this again requires a way to copy an mjx model back to a regular mujoco model (mjx.get_model). At least this would offer the same simple API as option 1 (callable from within vmap) without the need to fully port mj_setConst(mjModel* m, mjData* d).
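To give a feel for what pieces of option 1 might look like, here is a hedged sketch that ports only one of the constants recomputed by mj_setConst, body_subtreemass, to JAX. It assumes MuJoCo's body ordering, where each body's parent has a smaller index (body_parentid[i] < i for i > 0). It takes plain arrays rather than an mjx.Model, and the function name `subtree_mass` is hypothetical. Constants such as dof_M0 would additionally need the kinematics and composite-inertia passes, which this does not attempt.

```python
import jax
import jax.numpy as jnp
import numpy as np


def subtree_mass(body_mass, body_parentid):
  """Hypothetical JAX port of the body_subtreemass computation.

  body_parentid is a static NumPy array (the tree topology is not randomized),
  while body_mass may be a traced JAX array, so this works under jit and vmap.
  """
  subtree = body_mass
  # Accumulate leaf-to-root: visit bodies in reverse index order and add each
  # body's subtree mass to its parent. Body 0 (world) has no parent.
  for i in range(len(body_parentid) - 1, 0, -1):
    subtree = subtree.at[body_parentid[i]].add(subtree[i])
  return subtree


# Toy tree: world(0) -> 1 -> 2, and world(0) -> 3.
parentid = np.array([0, 0, 1, 0])
mass = jnp.array([0.0, 1.0, 2.0, 3.0])
single = subtree_mass(mass, parentid)

# Per-environment randomized masses for 4 environments (scale range is arbitrary).
keys = jax.random.split(jax.random.PRNGKey(0), 4)
scales = jax.vmap(
    lambda k: jax.random.uniform(k, mass.shape, minval=0.8, maxval=1.2))(keys)
batched_mass = mass * scales
batched_subtree = jax.jit(jax.vmap(lambda m: subtree_mass(m, parentid)))(
    batched_mass)
```

For the toy tree, `single` evaluates to `[6., 3., 2., 3.]`. Since the loop bounds depend only on the static topology, the same function can be vmapped over the randomized masses inside a domain_randomize function. Which of these fields MJX actually reads at runtime would still need to be checked against the MJX source.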

btaba commented 1 month ago

Thanks for opening the bug, @JeyRunner! Option 1 sounds best to me. Feel free to open a PR! Indeed, this functionality would be really useful for randomizing mass properties correctly.