google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

[FEATURE] Vsys feature: massively parallel domain randomization #458

Status: open. Velythyl opened this pull request 4 months ago.

Velythyl commented 4 months ago

Hello!

For an unrelated research project, I needed a massively parallel RL environment with domain randomization capabilities. Isaac Sim/Gym/Omniverse fit the bill, but I also needed the simulator to be differentiable with respect to each domain randomization parameter.

So I set out to implement domain randomization (DR) in brax. This is research code, so it's obviously a little janky and ad hoc. But I thought the brax community might find it interesting, and perhaps (with a lot of tuning) it could even be merged into brax main.

Special thanks to this GitHub issue, from which I stole some code ;) here

Note that this domain randomization method is more powerful than this: with this code, we can re-randomize at every single simulation step, if we so wish.

The implementation is simple to summarize: we augment the simulation state to contain sys, which gives every parallel environment access to its own separate sys. This also lets us resample sys according to some rule (for example, "resample every 50 steps").
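A minimal, framework-free sketch of that idea, assuming a toy 1-D point mass in place of a real brax `sys`: each environment's state carries its own parameters, and a hypothetical `RESAMPLE_EVERY` rule redraws them every 50 steps. `SysParams`, `EnvState`, `sample_sys`, `step`, and `step_all` are illustrative names, not brax or PR code; in JAX the Python loop over environments would become `jax.vmap`, and the `if` would become `jnp.where` or `jax.lax.cond`.

```python
import random
from dataclasses import dataclass
from typing import List

RESAMPLE_EVERY = 50  # hypothetical rule: "resample every 50 steps"


@dataclass(frozen=True)
class SysParams:
    """Toy stand-in for a per-environment brax `sys` (only a mass here)."""
    mass: float


@dataclass(frozen=True)
class EnvState:
    """Simulation state augmented with its own sys, one per parallel env."""
    position: float
    velocity: float
    sys: SysParams
    step_count: int


def sample_sys(rng: random.Random) -> SysParams:
    # Draw a fresh mass for one environment.
    return SysParams(mass=rng.uniform(0.5, 1.5))


def step(state: EnvState, force: float, rng: random.Random,
         dt: float = 0.01) -> EnvState:
    # Integrate with this environment's *own* sys (semi-implicit Euler).
    velocity = state.velocity + dt * force / state.sys.mass
    position = state.position + dt * velocity
    count = state.step_count + 1
    # Resample the sys carried in the state whenever the rule fires.
    sys = sample_sys(rng) if count % RESAMPLE_EVERY == 0 else state.sys
    return EnvState(position, velocity, sys, count)


def step_all(states: List[EnvState], force: float,
             rng: random.Random) -> List[EnvState]:
    # Python stand-in for vmapping `step` over the batch of environments.
    return [step(s, force, rng) for s in states]
```

Stepping a batch of four environments with `step_all` keeps the masses drawn at reset for steps 1 to 49, and every environment holds a freshly drawn mass after step 50. Because the parameters live in the state rather than in the environment object, the same mechanism supports resampling at any cadence, including every step.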


Features:

The vsys wrapper allows for a vectorized sys variable that can hold different domain randomization values for each vectorized environment.
Domain randomization is controlled via a simple YAML file format that describes the path to each domain randomization target. Example:
link:
  inertia:
    mass:
      base: [r, r, r, r, r, r, r]
      min: [-0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5]
      max: [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]
  constraint_ang_damping:
    min: [-1, -1, -1, 1, 1, 1, 1]
    max: [2, 2, 2, 1.5, 1, 1, 1]

This randomizes over the 7 links of the robot. For the mass, the base is "r", so the value is "read" from the default defined in the URDF file. The min and max values are offsets relative to the base, so this setup samples each mass from [r - 0.5, r + 0.5]. For the damping, no base is given, so it defaults to "r". The base can also be set to a float value. Another possible base value is "n" ("none"), which disables randomization for that index.
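The base/min/max rules above can be sketched in plain Python, assuming the YAML has already been loaded into a dict (for example with PyYAML's `yaml.safe_load`) and that the defaults read from the robot file are available as a list. `resolve_base` and `sample_target` are hypothetical helpers, not the PR's code, and drawing uniformly from `[base + min, base + max]` is an assumption about the sampling distribution.

```python
import random
from typing import Dict, List, Optional, Sequence, Union

BaseEntry = Union[str, float]  # "r" (read default), "n" (none), or a number


def resolve_base(base: Sequence[BaseEntry],
                 defaults: Sequence[float]) -> List[Optional[float]]:
    """Turn each base entry into a float, or None when randomization is off."""
    resolved: List[Optional[float]] = []
    for entry, default in zip(base, defaults):
        if entry == "r":
            resolved.append(float(default))  # read the model's default value
        elif entry == "n":
            resolved.append(None)            # randomization disabled here
        else:
            resolved.append(float(entry))    # explicit numeric base
    return resolved


def sample_target(spec: Dict, defaults: Sequence[float],
                  rng: random.Random) -> List[float]:
    """Sample one value per index; min/max are offsets from the base."""
    base = spec.get("base", ["r"] * len(defaults))  # missing base -> all "r"
    values: List[float] = []
    for i, b in enumerate(resolve_base(base, defaults)):
        if b is None:
            values.append(float(defaults[i]))  # "n": keep the default as-is
        else:
            values.append(b + rng.uniform(spec["min"][i], spec["max"][i]))
    return values
```

Note that an index whose min and max are equal (such as the last three damping entries in the example) is not random at all: it is shifted by a fixed offset from its base.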

Domain randomization is differentiable (!)

For example, by running a simple optax optimizer, we can recover the true domain randomization parameters in effect at a specific timestep.
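To show why differentiability enables this, here is a toy sketch (not the PR's code): a 1-D velocity with linear damping is rolled out for a fixed number of Euler steps, and plain gradient descent on the squared error between simulated and observed final velocity recovers the damping coefficient. The derivative is written out by hand so the example runs without JAX; in brax one would instead differentiate through the rollout with `jax.grad` and could feed that gradient to an optax optimizer.

```python
def rollout(damping: float, v0: float = 1.0, dt: float = 0.1,
            steps: int = 10) -> float:
    """Final velocity after `steps` Euler steps of dv/dt = -damping * v."""
    v = v0
    for _ in range(steps):
        v = v * (1.0 - dt * damping)
    return v


def grad_rollout(damping: float, v0: float = 1.0, dt: float = 0.1,
                 steps: int = 10) -> float:
    """d(final velocity)/d(damping), from the closed form v0*(1 - dt*c)**steps."""
    return v0 * steps * (1.0 - dt * damping) ** (steps - 1) * (-dt)


def fit_damping(observed: float, init: float, lr: float = 1.0,
                iters: int = 500) -> float:
    """Gradient descent on (rollout(c) - observed)**2 with respect to c."""
    c = init
    for _ in range(iters):
        residual = rollout(c) - observed
        c -= lr * 2.0 * residual * grad_rollout(c)  # chain rule
    return c
```

Starting from a guess of `1.0`, `fit_damping(rollout(0.5), init=1.0)` converges to the "true" damping of `0.5`. The same idea scales to many randomized parameters once the gradient comes from automatic differentiation rather than a hand-derived formula.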

Known issues:


Again, I don't expect this to be merged as-is, but the implementation may be of interest to the community, hence this PR.

lebrice commented 4 months ago

Hey @Velythyl, I've been looking forward to this feature for a while now; thanks a lot for sharing it! I'm just curious, why did you close the PR?

Velythyl commented 4 months ago

@lebrice Hey! Sorry, I realized I had some cleanup to do, and it was way past 5pm, so I wanted to go home. I've reopened it now.

btaba commented 2 months ago

Thanks @Velythyl! The recent comment just made me realize that the maintainers hadn't commented on the PR yet. There were a few design decisions that went into DomainRandomizationVmapWrapper:

  1. We saw better performance when sys was not stored as part of State.
  2. We wanted the user to fully define the randomization strategy rather than have a schema. At HEAD, this can be done via the randomization_fn.
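For readers new to the `randomization_fn` hook, its contract can be emulated in plain Python: the function returns a batched copy of `sys` plus an `in_axes` tree marking which fields carry a per-environment axis (`0`) and which are shared (`None`), which is the form `jax.vmap` consumes. This sketch uses a dict in place of brax's `System` pytree. brax's own hook takes just `sys` and returns a `(sys, in_axes)` pair, so the `num_envs` and `seed` arguments here, along with the `slice_sys` and `per_env_sys` helpers, are hypothetical simplifications.

```python
import random
from typing import Any, Dict, List, Optional, Tuple

Sys = Dict[str, Any]               # toy stand-in for brax's System pytree
InAxes = Dict[str, Optional[int]]  # 0 = batched on axis 0, None = shared


def randomization_fn(sys: Sys, num_envs: int,
                     seed: int = 0) -> Tuple[Sys, InAxes]:
    """User-defined strategy: randomize only the mass, +/-20% per env."""
    rng = random.Random(seed)
    sys_v = dict(sys)  # shallow copy; the base sys is left untouched
    sys_v["mass"] = [sys["mass"] * rng.uniform(0.8, 1.2)
                     for _ in range(num_envs)]
    in_axes: InAxes = {key: None for key in sys}
    in_axes["mass"] = 0  # only the mass gets a per-environment axis
    return sys_v, in_axes


def slice_sys(sys_v: Sys, in_axes: InAxes, i: int) -> Sys:
    """What env i sees under vmap: index batched fields, share the rest."""
    return {k: (v[i] if in_axes[k] == 0 else v) for k, v in sys_v.items()}


def per_env_sys(sys_v: Sys, in_axes: InAxes, num_envs: int) -> List[Sys]:
    return [slice_sys(sys_v, in_axes, i) for i in range(num_envs)]
```

Keeping shared fields at `None` means they are not copied per environment, which is one reason a user-defined function with explicit `in_axes` can be cheaper than carrying a full sys in every state.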

The downsides of the implementation at HEAD are:

  1. The reset is static: the randomized values are stored in the wrapper (this PR addresses that).
  2. Simple randomization strategies still require the user to write a randomization_fn.

What I think would make sense to merge is a wrapper with the same API as DomainRandomizationVmapWrapper that passes in_axes and the randomized Sys PyTree values in the State, as discussed in this thread: https://github.com/google/brax/issues/446
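One way to read that proposal, sketched in plain Python under toy assumptions (a dict-based `sys` and a hypothetical `randomization_fn(sys, num_envs, seed)` signature): the wrapper keeps the `randomization_fn` interface, but `reset` draws fresh randomized values and stores them, with their `in_axes`, in the returned state, and `step` reads them back from the state instead of from the wrapper. `StatefulDRWrapper` and `State` are hypothetical names, not brax classes.

```python
import random
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple

Sys = Dict[str, Any]
InAxes = Dict[str, Optional[int]]


@dataclass
class State:
    velocity: List[float]  # one entry per environment
    info: Dict[str, Any]   # carries the randomized sys and its in_axes


def randomization_fn(sys: Sys, num_envs: int,
                     seed: int) -> Tuple[Sys, InAxes]:
    rng = random.Random(seed)
    sys_v = dict(sys)
    sys_v["mass"] = [sys["mass"] * rng.uniform(0.8, 1.2)
                     for _ in range(num_envs)]
    in_axes = {k: (0 if k == "mass" else None) for k in sys}
    return sys_v, in_axes


class StatefulDRWrapper:
    def __init__(self, sys: Sys, randomization_fn: Callable, num_envs: int):
        self.sys = sys
        self.randomization_fn = randomization_fn
        self.num_envs = num_envs

    def reset(self, seed: int) -> State:
        # Resample on every reset instead of once at construction time.
        sys_v, in_axes = self.randomization_fn(self.sys, self.num_envs, seed)
        return State([0.0] * self.num_envs,
                     {"sys_v": sys_v, "in_axes": in_axes})

    def step(self, state: State, forces: List[float],
             dt: float = 0.01) -> State:
        # Read the per-env parameters from the state, not from the wrapper.
        sys_v, in_axes = state.info["sys_v"], state.info["in_axes"]
        new_v = []
        for i, (v, f) in enumerate(zip(state.velocity, forces)):
            mass = sys_v["mass"][i] if in_axes["mass"] == 0 else sys_v["mass"]
            new_v.append(v + dt * f / mass)
        return State(new_v, state.info)
```

In real JAX code `in_axes` passed to `jax.vmap` must be static Python values (ints or `None`), so in practice it would travel as static metadata alongside the traced sys arrays rather than as arrays itself.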