google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

domain randomization #148

Open venkatesh-narayan opened 2 years ago

venkatesh-narayan commented 2 years ago

Is it possible to perform domain randomization in Brax? I'd like to change the coefficient of friction of the ground, the inertias and lengths of the links, etc., and train on all of these variants in parallel with something like PPO. Judging from some other GitHub issue posts, it seems like it isn't fully possible, but I wanted to make sure first.

cdfreeman-google commented 2 years ago

Hello!

This is not yet supported out of the box, but there are a few workarounds. The easiest one (call it Option 1) is to simply run many different systems asynchronously (as suggested by Erik here: https://github.com/google/brax/issues/143). This loses out on a lot of the performance we baked into our algorithms, which assume everything is happening on device in parallel.

Some other workarounds:

Option 2: Make your system contain multiple independent variants of the problem you care about.

This would require some data reshaping, because our PPO algorithm currently assumes data is independent across the batch dimension, so you would need to rewrite the bit of code that makes this assumption. Instead of data of shape [128, #-bodies, 3], you could use, e.g., data of shape [4, 32*#-bodies, 3], with 32 independent "scenes" per system. This would also fix the randomization to whatever you set in those 32 individual scenes (which is maybe not ideal if you're imagining doing an adaptive curriculum).
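The shape change in Option 2 can be sketched in plain JAX. This is only an illustration of the array bookkeeping, not Brax code: the names `num_bodies`, `scenes_per_system`, `packed`, and `unpacked` are hypothetical, and the body count of 5 is an arbitrary assumption.

```python
import jax.numpy as jnp

# Assumed sizes for illustration only.
num_envs = 128           # original batch size
num_bodies = 5           # bodies per scene (arbitrary choice)
scenes_per_system = 32   # independent scenes packed into one system
num_systems = num_envs // scenes_per_system  # 4

# Stand-in for per-body position data of shape [128, #-bodies, 3].
pos = jnp.arange(num_envs * num_bodies * 3, dtype=jnp.float32)
pos = pos.reshape(num_envs, num_bodies, 3)

# Pack 32 consecutive scenes into each system: [4, 32*#-bodies, 3].
packed = pos.reshape(num_systems, scenes_per_system * num_bodies, 3)

# Undo the packing to recover per-scene data, e.g. before code that
# assumes every batch entry is an independent environment.
unpacked = packed.reshape(num_systems * scenes_per_system, num_bodies, 3)
```

With this row-major layout, scene `k` of system `s` occupies rows `k*num_bodies` to `(k+1)*num_bodies - 1` of `packed[s]`, so any code that computes per-environment rewards or observations would need to slice along that axis.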

Option 3a: Add extra data fields for data that varies across the batch. Right now, the only data that can differ across the batch is either stored in QPs or set by randomness (e.g., different target locations for locomotion in Fetch). You could add an extra dynamic_length field to QP that is read on every step and directly controls something like the length of a part, but you'd have to be careful to make sure this field is injected everywhere that length is read in the code.

Brax is currently organized to read this kind of data once at system initialization, whereupon it builds auxiliary data structures that it then uses during simulation. For data that's only used in a couple of places, this would be easy and would not require plumbing many args around. Other data can be a bit trickier: masses are used all over the place, so you'd have to hunt all of those down to do this correctly.
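The idea behind Option 3a, carrying a per-environment parameter inside the batched state so the step function reads it on every step, can be sketched with a toy example. This is not Brax's QP or System API: `ToyState`, `toy_step`, and the `friction` field are hypothetical stand-ins, and linear drag is used as a crude placeholder for ground friction.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical state; `friction` plays the role of an extra QP field."""
    pos: jnp.ndarray       # (3,)
    vel: jnp.ndarray       # (3,)
    friction: jnp.ndarray  # scalar, differs per environment


def toy_step(state: ToyState, dt: float = 0.01) -> ToyState:
    # The parameter is read from the state on every step instead of
    # being baked in once at initialization.
    vel = state.vel * (1.0 - state.friction * dt)
    pos = state.pos + vel * dt
    return ToyState(pos, vel, state.friction)


batch = 128
frictions = jax.random.uniform(
    jax.random.PRNGKey(0), (batch,), minval=0.5, maxval=1.5)
states = ToyState(
    pos=jnp.zeros((batch, 3)),
    vel=jnp.ones((batch, 3)),
    friction=frictions,
)

# vmap maps the single-environment step over the leading batch axis,
# so each environment uses its own friction value.
batched_step = jax.jit(jax.vmap(toy_step))
next_states = batched_step(states)
```

In a real system the hard part is the one described above: every place that reads the parameter has to take it from the state rather than from data cached at initialization.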

Option 3b: A slightly more general version of this is to add plumbing for an RNG to the System step, and then use this RNG to sample those data values in the course of simulation. It would require the same amount of plumbing as in 3a, but there'd only be one extra data field for the QP (probably an rng field), and the actual selection of the "length" (or whatever else you randomize) could happen in code.
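A minimal sketch of Option 3b, again on a toy system rather than Brax itself: the state carries one PRNG key per environment, and the step function derives the randomized value from it. `ToyRngState`, `sample_length`, and `toy_rng_step` are hypothetical names. One assumption made here is that the key is reused unchanged on each step, so the sampled length stays fixed within an episode; resetting an environment with a fresh key would resample it.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyRngState(NamedTuple):
    """Hypothetical state with a single extra field: a PRNG key."""
    angle: jnp.ndarray  # scalar joint angle
    rng: jnp.ndarray    # per-environment PRNG key


def sample_length(rng: jnp.ndarray) -> jnp.ndarray:
    # The choice of what to randomize, and over which range, lives in code.
    return jax.random.uniform(rng, (), minval=0.8, maxval=1.2)


def toy_rng_step(state: ToyRngState):
    # Same key on every step, so the length is constant within an episode.
    length = sample_length(state.rng)
    tip = length * jnp.stack([jnp.cos(state.angle), jnp.sin(state.angle)])
    return ToyRngState(state.angle + 0.1, state.rng), tip


batch = 128
keys = jax.random.split(jax.random.PRNGKey(0), batch)
rng_states = ToyRngState(angle=jnp.zeros(batch), rng=keys)

rng_batched_step = jax.jit(jax.vmap(toy_rng_step))
rng_states, tips_a = rng_batched_step(rng_states)
rng_states, tips_b = rng_batched_step(rng_states)
```

Compared with the 3a sketch, adding another randomized quantity here means adding another sampling function rather than another state field.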

Let me know if that doesn't make sense!