google-deepmind / mujoco

Multi-Joint dynamics with Contact. A general purpose physics simulator.
https://mujoco.org
Apache License 2.0

Help with MJX speedup #1764

Open jdao913 opened 5 days ago

jdao913 commented 5 days ago

Hello,

I'm a student using MuJoCo for RL research on bipeds/humanoids, and I wanted some help optimizing my models for use with MJX. I have attached the two models I am trying to use.

I used both the testspeed binary and the handy MJX benchmark script from Gregwar to compare timings.

I am getting roughly the same sampling throughput in MJX for both the Cassie and Digit models (both around 13-14x realtime on a 3060 and ~93x realtime on a 4090, sampling 2048 model instances), despite the extra DoFs and constraints in the Digit model, which is promising. But it is still less of a speedup than I was hoping for. For comparison, single-threaded CPU MuJoCo gives me 24x realtime for Cassie and 14x realtime for Digit.
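For reference, an "Nx realtime" figure for a batched run is usually total simulated time divided by wall-clock time. The sketch below assumes the numbers above are aggregated over all 2048 instances (the thread does not say which convention the benchmark script uses), and the step count, timestep and wall time are made-up values, not measurements.

```python
def realtime_factor(steps_per_env: int, timestep: float, n_envs: int,
                    wall_seconds: float) -> float:
    """Aggregate realtime factor: simulated seconds summed over all
    instances, divided by elapsed wall-clock seconds."""
    if wall_seconds <= 0:
        raise ValueError("wall_seconds must be positive")
    simulated_seconds = steps_per_env * timestep * n_envs
    return simulated_seconds / wall_seconds


# Hypothetical run: 2048 instances, 0.002 s timestep, 1000 steps each,
# finishing in 300 s of wall time.
aggregate = realtime_factor(steps_per_env=1000, timestep=0.002, n_envs=2048,
                            wall_seconds=300.0)
per_instance = aggregate / 2048
print(f"aggregate {aggregate:.2f}x, per instance {per_instance:.4f}x")
```

Under that assumption each individual instance runs far slower than realtime, so the comparison with single-threaded CPU MuJoCo is about total throughput rather than per-instance speed.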

I also tried disabling eulerdamp, following the suggestion in the MJX documentation, but got some odd behavior. The Cassie model still runs stably, with a slight slowdown in MJX (~12.5x realtime on a 3060) and a slight speedup on CPU (~27x realtime). The Digit model, however, seems unstable: I get the QACC warning message in testspeed and a massive slowdown in MJX (4x realtime on a 3060), which I assume comes from the unstable sim. (By the way, does MJX suppress the usual unstable-sim warnings? I never saw them, even with solver iterations set to 1.)

I was hoping for some advice on other changes I can make to my models to speed things up, or on whether I'm running into pitfalls with JAX or CUDA optimizations. I know I can replace the rod equality constraints with tendon constraints, which speeds things up (I've already done that for the CPU versions of our models), but I'm still waiting for tendons to get MJX support :disappointed:

Perhaps someone with more intuition can give me a better idea of what kind of speedup I should expect from MJX given the complexity of the model, or whether the speedup I'm seeing in my 4090 tests is about the limit.
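One way to see whether the 4090 numbers are near the limit is to sweep the batch size and watch where aggregate throughput stops growing. Below is a minimal, library-agnostic sketch; make_step is a hypothetical hook you would supply. For MJX it would typically wrap something like jax.jit(jax.vmap(mjx.step, in_axes=(None, 0))) and call .block_until_ready() on the result, so that asynchronous dispatch does not distort the timing.

```python
import time
from typing import Callable, Dict, Iterable


def sweep_batch_sizes(make_step: Callable[[int], Callable[[], None]],
                      batch_sizes: Iterable[int],
                      timestep: float,
                      n_steps: int = 100) -> Dict[int, float]:
    """Measure the aggregate realtime factor for each batch size.

    make_step(n) is a hypothetical user-supplied factory that returns a
    zero-argument callable advancing n instances by one step and blocking
    until that work has finished.
    """
    results: Dict[int, float] = {}
    for n in batch_sizes:
        step = make_step(n)
        step()  # warm-up call, so one-time compilation is not timed
        start = time.perf_counter()
        for _ in range(n_steps):
            step()
        elapsed = time.perf_counter() - start
        # Simulated seconds summed over all n instances per wall-clock second.
        results[n] = n_steps * timestep * n / elapsed
    return results


def make_dummy_step(n: int) -> Callable[[], None]:
    # Stand-in workload for illustration only; replace with a real batched step.
    def step() -> None:
        sum(range(n))
    return step


for n, rtf in sweep_batch_sizes(make_dummy_step, [256, 1024, 4096],
                                timestep=0.002).items():
    print(n, round(rtf, 1))
```

If doubling the batch size no longer roughly doubles the aggregate figure, the device is saturated and additional instances mostly lengthen each step.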

mjx_models.zip

Thanks! Jeremy Dao

btaba commented 5 days ago

Hi @jdao913, MJX doesn't print the same warnings, but you should make sure the model is stable in CPU MuJoCo first before moving to MJX. If the model is stable on CPU but not in MJX, try fiddling with the solver params, or try jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGH); we've seen numerical issues on certain cards get resolved with this precision flag. If you add from jax import config; config.update("jax_debug_nans", True), JAX should flag any NaN issues while you debug.
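Since MJX does not emit MuJoCo's warnings, one option is to check the batched state yourself after stepping. The sketch below mirrors the CPU test behind the QACC warning, where a value counts as bad if it is NaN or its magnitude exceeds mjMAXVAL (1e10 in MuJoCo's headers). It assumes you pass the qacc field of a batched mjx.Data converted with np.asarray, with shape (n_envs, nv); the name batched_data is hypothetical.

```python
import numpy as np

# mjMAXVAL in MuJoCo's headers: values beyond this count as "bad".
MJ_MAXVAL = 1e10


def unstable_instances(qacc, maxval: float = MJ_MAXVAL) -> np.ndarray:
    """Indices of instances whose qacc row has NaN/inf or |value| > maxval.

    qacc: array-like of shape (n_envs, nv), e.g. np.asarray(batched_data.qacc)
    where batched_data is a (hypothetical) vmapped mjx.Data.
    """
    qacc = np.asarray(qacc, dtype=np.float64)
    with np.errstate(invalid="ignore"):
        bad = ~np.isfinite(qacc) | (np.abs(qacc) > maxval)
    return np.flatnonzero(bad.any(axis=-1))


example = np.array([[0.0, 1.0],
                    [np.nan, 0.0],
                    [2e10, 0.0],
                    [-1e9, 5.0]])
print(unstable_instances(example))
```

Running this periodically during a batched rollout shows which instances diverged, which CPU MuJoCo would otherwise report through its warnings.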

In terms of performance, MJX is still heavily bottlenecked by collisions with meshes (we're working on it). Try adding <default> <mesh maxhullvert="64"/> </default> (although I see you're only using primitives?). Also try the max_contact_points and max_geom_pairs fields for contact culling. I can help debug if you get stuck after incorporating the above. You should consider adding the models to MuJoCo Menagerie once you have good performance.
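Both suggestions live in the MJCF file. Here is a small standard-library sketch that adds them to a model string. It assumes MJX reads max_contact_points and max_geom_pairs from <custom><numeric> fields under the names used in this thread, so check the MJX documentation for your version. The maxhullvert default only affects mesh geoms, so it does nothing for primitive-only models like these.

```python
import xml.etree.ElementTree as ET


def add_mjx_culling(mjcf: str, max_contact_points: int, max_geom_pairs: int,
                    maxhullvert: int = 64) -> str:
    """Return the MJCF string with MJX culling fields and a mesh hull default."""
    root = ET.fromstring(mjcf)

    custom = root.find("custom")
    if custom is None:
        custom = ET.SubElement(root, "custom")
    for name, value in (("max_contact_points", max_contact_points),
                        ("max_geom_pairs", max_geom_pairs)):
        # Reuse an existing <numeric> with this name instead of duplicating it.
        numeric = custom.find(f"numeric[@name='{name}']")
        if numeric is None:
            numeric = ET.SubElement(custom, "numeric", name=name)
        numeric.set("data", str(value))

    default = root.find("default")
    if default is None:
        default = ET.SubElement(root, "default")
    mesh = default.find("mesh")
    if mesh is None:
        mesh = ET.SubElement(default, "mesh")
    # Caps the convex hull vertex count for mesh geoms; primitives are unaffected.
    mesh.set("maxhullvert", str(maxhullvert))

    return ET.tostring(root, encoding="unicode")


print(add_mjx_culling('<mujoco model="demo"><worldbody/></mujoco>', 20, 20))
```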

jdao913 commented 5 days ago

Thanks so much for the quick reply! Good to know about the MJX warnings and the JAX NaN checks.

Yeah, our models don't have mesh collisions; in fact they don't have meshes at all. The contact culling helps quite a lot, though, so that's good advice! Limiting to 20 contact points/geom pairs gets me a +3x speedup right away. Could I ask for some clarification on what exactly these fields mean? Is "max_geom_pairs" just the maximum number of geom pairs considered each iteration? How are pairs generated/pruned if more than the maximum number exist?
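To make the question concrete, capping candidates usually means keeping the k "best" ones by some proximity score. The plain-Python sketch below shows that generic top-k idea. How MJX actually scores and groups pairs (for example, whether it ranks by bounding-volume distance or applies the cap per geom-type combination) is an assumption here, not something stated in this thread. When the candidate list is already shorter than the cap, nothing is dropped.

```python
import heapq
from typing import List, Sequence, Tuple

Pair = Tuple[int, int]


def cull_candidates(pairs: Sequence[Pair], scores: Sequence[float],
                    max_keep: int) -> List[Pair]:
    """Keep at most max_keep candidates with the smallest score.

    Illustrative only: score stands in for whatever proximity measure a
    collision pipeline might use (e.g. distance between bounding volumes).
    """
    if len(pairs) != len(scores):
        raise ValueError("pairs and scores must have the same length")
    if len(pairs) <= max_keep:
        return list(pairs)  # list already within the cap: nothing is dropped
    keep = heapq.nsmallest(max_keep, range(len(pairs)), key=lambda i: scores[i])
    return [pairs[i] for i in keep]


candidates = [(0, 1), (0, 2), (1, 2), (2, 3)]
distances = [0.5, -0.1, 2.0, 0.3]  # negative = penetrating (made-up values)
print(cull_candidates(candidates, distances, max_keep=2))
```

Inside a jitted JAX program the same idea would need a fixed-size selection such as jax.lax.top_k (on negated distances), because array shapes must be known at trace time.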

This also made me realize that I don't fully understand the contact/pair field. I was under the impression that defining pairs would build an explicit list of contact pairs, and that only this list would be checked for contacts, with no other dynamically generated pairs to filter and check. However, the docs suggest this is not the case and that I would need to change the "collision" attribute of option, but that attribute does not seem to exist.

So do I actually have to individually enumerate every contact/exclude pair that I don't want to generate? How do I properly define the short list of contact pairs that will be generated before filtering and checking? And in that case, how do the "max_geom_pairs" and "max_contact_points" fields affect things, especially if the generated list is already smaller than the specified maximum?
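For the contact/pair question, one approach (verify that your MJX version supports explicit pairs) is to turn off dynamic pair generation by giving every geom contype="0" and conaffinity="0" through the top-level default class, then list the wanted pairs under <contact>. Explicit <pair> elements are not filtered by contype/conaffinity. Geoms or nested default classes that set contype/conaffinity themselves still override this default. The geom names below are hypothetical.

```python
import xml.etree.ElementTree as ET
from typing import Iterable, Tuple


def use_explicit_pairs_only(mjcf: str, pairs: Iterable[Tuple[str, str]]) -> str:
    """Disable dynamic collision pairs and declare explicit geom pairs."""
    root = ET.fromstring(mjcf)

    # Top-level default: every geom that does not override these gets
    # contype=0 and conaffinity=0, so no pairs are generated dynamically.
    default = root.find("default")
    if default is None:
        default = ET.SubElement(root, "default")
    geom_default = default.find("geom")
    if geom_default is None:
        geom_default = ET.SubElement(default, "geom")
    geom_default.set("contype", "0")
    geom_default.set("conaffinity", "0")

    # Explicit pairs are checked regardless of contype/conaffinity.
    contact = root.find("contact")
    if contact is None:
        contact = ET.SubElement(root, "contact")
    for geom1, geom2 in pairs:
        ET.SubElement(contact, "pair", geom1=geom1, geom2=geom2)

    return ET.tostring(root, encoding="unicode")


# Hypothetical geom names, for illustration only.
model = ('<mujoco><worldbody><geom name="floor" type="plane" size="5 5 0.1"/>'
         '<body><geom name="left_foot" type="box" size="0.1 0.05 0.02"/>'
         '</body></worldbody></mujoco>')
print(use_explicit_pairs_only(model, [("left_foot", "floor")]))
```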

I also have another question about how much GPU memory MJX uses. Pardon my unfamiliarity with JAX; I'm not sure whether this is an MJX thing or a JAX thing in general. I noticed that the amount of GPU memory used (or at least allocated) by the MJX Python process is always the same no matter how many model instances I use, be it 10 or 4096. It seems to always take up as much GPU memory as it can without crashing. How does this work? So when choosing the number of parallel models to instantiate, the limit isn't how many fit onto the GPU, but simply the point where I start to get diminishing returns in sample efficiency?
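The constant footprint is most likely JAX's default GPU preallocation rather than anything MJX-specific: per the JAX GPU memory allocation docs, JAX reserves a large fixed fraction of total GPU memory (75% by default in recent versions) when the first JAX operation runs. The documented XLA_PYTHON_CLIENT_PREALLOCATE and XLA_PYTHON_CLIENT_MEM_FRACTION environment variables change this. The helper below is a hypothetical convenience wrapper around them and is meant to run before jax is imported.

```python
import os
import sys
from typing import Optional


def configure_jax_gpu_memory(preallocate: bool = False,
                             mem_fraction: Optional[float] = None) -> None:
    """Hypothetical helper: set JAX's GPU allocator env vars before import."""
    if "jax" in sys.modules:
        # XLA reads these when the GPU backend is initialized; setting them
        # before importing jax is the safe order.
        raise RuntimeError("call this before importing jax")
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if mem_fraction is not None:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError("mem_fraction must be in (0, 1]")
        # Fraction of total GPU memory to reserve when preallocation is on.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"


configure_jax_gpu_memory(preallocate=False)
```

Import jax only after this call. With preallocation disabled, tools like nvidia-smi report something closer to what the program actually uses, which makes it easier to see how memory grows with the number of instances; the JAX docs note that preallocation reduces allocation overhead and fragmentation, so you may want it back on for long training runs.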