Status: Open. JeyRunner opened this issue 7 months ago.
Bumping into the same problem. Having a solution to this issue would definitely make the domain randomisation functionality much more useful.
I guess the best solution would be to port everything in `mj_setConst` (https://github.com/google-deepmind/mujoco/blob/1d181786a25da07a9ac536ae0e78291b221e27f0/src/engine/engine_setconst.c#L506C6-L506C17) to MJX/JAX, but I am not sure how much work that would be.
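Porting all of `mj_setConst` is a larger job. `dof_M0`, for example, depends on the kinematics and composite inertias at `qpos0`. As a hedged sketch of what one small piece of such a port might look like, the simplest dependent constant, `body_subtreemass`, can be recomputed in pure JAX from `body_mass` and the static `body_parentid` array. The function name `subtree_mass` is hypothetical. The sketch assumes MuJoCo's body ordering, where every parent id is smaller than its child's id, so one backward pass over the bodies is enough.

```python
import jax
import jax.numpy as jnp
import numpy as np


def subtree_mass(body_mass: jax.Array, body_parentid: np.ndarray) -> jax.Array:
    """Hypothetical JAX port of the subtree-mass accumulation.

    body_parentid is a static NumPy array, so the Python loop is unrolled
    at trace time and the function works under jit and vmap.
    """
    subtree = body_mass
    # Walk bodies from last to first, adding each subtree into its parent.
    # Body 0 (world) is its own parent, so the loop stops before it.
    for i in range(len(body_parentid) - 1, 0, -1):
        subtree = subtree.at[int(body_parentid[i])].add(subtree[i])
    return subtree


# Toy tree: world(0) -> body1 -> {body2, body3}.
parentid = np.array([0, 0, 1, 1])
masses = jnp.array([0.0, 1.0, 2.0, 3.0])
print(subtree_mass(masses, parentid))  # [6. 6. 2. 3.]

# Batched over randomized masses, e.g. one row per environment.
batch = jnp.stack([masses, 2.0 * masses])
batched = jax.vmap(lambda m: subtree_mass(m, parentid))(batch)
print(batched.shape)  # (2, 4)
```

Keeping the tree topology as a static NumPy array (rather than a traced array) is what lets the loop unroll cleanly; only the masses are traced and batched.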
In the meantime, the copy-model method could be used as a workaround, but it may carry a significant performance cost, especially in an RL training loop where domain randomization is performed on every reset call across many environments.
Another workaround may be to call `mj_setConst(mjModel* m, mjData* d)` via `jax.experimental.io_callback` from within the jitted `domain_randomize` function, but this again requires some way to copy an MJX model back to a normal MuJoCo model (`mjx.get_model`). At least this would provide the same simple API as (1.) (which can be called from within `vmap`) without the need to fully port `mj_setConst(mjModel* m, mjData* d)`.
Thanks for opening the bug @JeyRunner! Option 1 sounds the best to me. Feel free to open a PR! This functionality would indeed be really useful for randomizing mass properties correctly.
Having the same question. Are there any updates on this?
@hshi74 Apologies there hasn't been movement on this issue yet.
As a workaround, I recommend creating the randomized models in MuJoCo so that fields get compiled correctly, and calling `put_model` on each randomized model. You can then `tree_map`/concatenate the models in MJX. During training, avoid resetting the `mjx.Model` fields that need to be recompiled. We find that many policies do fine with a large batch of reset states that are set once at the beginning of training. Let us know if there are any updates otherwise!
Hi, I am a student using MJX in combination with Brax for RL. To implement domain randomization, I followed the example notebook, which looks like this:
Now I also want to randomize body masses and inertias, but according to https://github.com/google-deepmind/mujoco/issues/764#issuecomment-1464708759 this requires updating dependent constants like `dof_M0`... With the normal MuJoCo (non-MJX) pipeline, the solution seems to be calling `mj_setConst` after changing the link masses/inertias, so that the dependent constants are recalculated. With MJX this is not possible since, to my understanding, there is no way of copying an MJX model back to a MuJoCo model (like an `mjx.get_model` function). If such a function existed, the workflow could look like this: ... `mj_setConst` on each MuJoCo model ...

Or is there another way to call `mj_setConst` directly on a batched MJX model?