Velythyl opened this pull request 4 months ago.
Hey @Velythyl, I've been looking forward to this feature for a while now; thanks a lot for sharing this! I'm just curious: why did you close the PR?
@lebrice Hey! Sorry, I realized I had some cleanup to do, and it was way past 5 p.m., so I wanted to go home. I've reopened it now.
Thanks @Velythyl! The recent comment just made me realize that maintainers hadn't commented on the PR yet. There were a few design decisions that went into DomainRandomizationVmapWrapper:
- sys was not added as part of State
- randomization_fn

The cons of the impl at HEAD are that:
- randomization_fn

What I think would make sense to merge is a wrapper with the same API as DomainRandomizationVmapWrapper that passes in_axes and the randomized Sys PyTree values in the State, as discussed in this thread: https://github.com/google/brax/issues/446.
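A minimal sketch of the in_axes part of this proposal, assuming a toy two-parameter ToySys NamedTuple as a hypothetical stand-in for brax's System (ToySys, step and randomize are illustrative names, not brax API). jax.vmap's in_axes maps over axis 0 of every leaf of the randomized sys PyTree while broadcasting a shared force; storing the values in State is not shown.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToySys(NamedTuple):
    # Hypothetical stand-in for brax's System: two physical parameters.
    mass: jnp.ndarray
    damping: jnp.ndarray


def step(sys: ToySys, vel: jnp.ndarray, force: float, dt: float = 0.01) -> jnp.ndarray:
    # Damped point mass, m * dv/dt = f - c * v, one explicit Euler step.
    acc = (force - sys.damping * vel) / sys.mass
    return vel + dt * acc


def randomize(rng: jnp.ndarray, base: ToySys, num_envs: int) -> ToySys:
    # Hypothetical randomization_fn: uniform offsets in [-0.5, 0.5) around the base.
    k_mass, k_damp = jax.random.split(rng)

    def offsets(key: jnp.ndarray) -> jnp.ndarray:
        return jax.random.uniform(key, (num_envs,), minval=-0.5, maxval=0.5)

    return ToySys(mass=base.mass + offsets(k_mass), damping=base.damping + offsets(k_damp))


# in_axes mirrors the argument structure: map axis 0 of every leaf of the
# sys PyTree and of vel; None means the force is shared by all envs.
batched_step = jax.vmap(step, in_axes=(ToySys(mass=0, damping=0), 0, None))

base = ToySys(mass=jnp.asarray(1.0), damping=jnp.asarray(1.0))
sys_batch = randomize(jax.random.PRNGKey(0), base, num_envs=4)
vel = jnp.zeros(4)
new_vel = batched_step(sys_batch, vel, 1.0)  # one velocity per env, shape (4,)
```

Because in_axes is itself a PyTree prefix, the same pattern extends to a full System: leaves that are randomized get 0, and shared leaves get None.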
Hello!
For an unrelated research project, I needed a massively parallel RL environment with domain randomization capabilities. Isaac Sim/Gym/Omniverse fit the bill, but I also needed the simulator to be differentiable with respect to each domain randomization parameter.
So I set out to implement DR in brax. This is research code, so it's obviously a little janky and ad hoc. But I thought the brax community might find it interesting, and perhaps (with a lot of tuning) even merge it into brax main.
Special thanks to this GitHub issue, from which I stole some code ;)
Note that this domain randomization method is more powerful than this. With this code, we can re-randomize at every single simulation step if we wish.
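Re-randomizing during an episode relies on keeping the system parameters inside the environment state. A minimal sketch in plain JAX, assuming a toy one-parameter system and a hypothetical "resample every 50 steps" rule; State, ToySys, env_step and RESAMPLE_EVERY are illustrative names, not the PR's vsys code. Setting RESAMPLE_EVERY = 1 would re-randomize at every step.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

RESAMPLE_EVERY = 50  # hypothetical rule: "resample every 50 steps"


class ToySys(NamedTuple):
    # Hypothetical per-env physical parameters (stand-in for brax's System).
    mass: jnp.ndarray


class State(NamedTuple):
    vel: jnp.ndarray
    sys: ToySys       # each parallel env carries its own sys
    rng: jnp.ndarray  # per-env PRNG key used for resampling
    t: jnp.ndarray    # step counter


def sample_sys(key: jnp.ndarray) -> ToySys:
    # Hypothetical sampler: mass drawn uniformly from [0.5, 1.5).
    return ToySys(mass=jax.random.uniform(key, (), minval=0.5, maxval=1.5))


def reset(rng: jnp.ndarray) -> State:
    rng, key = jax.random.split(rng)
    return State(vel=jnp.zeros(()), sys=sample_sys(key), rng=rng, t=jnp.zeros((), jnp.int32))


def env_step(state: State, force: float, dt: float = 0.01) -> State:
    t = state.t + 1
    rng, key = jax.random.split(state.rng)
    # Swap in a freshly sampled sys when t is a multiple of RESAMPLE_EVERY;
    # jnp.where keeps this traceable under jit/vmap (no Python branching).
    resample = (t % RESAMPLE_EVERY) == 0
    new_sys = jax.tree_util.tree_map(
        lambda fresh, old: jnp.where(resample, fresh, old), sample_sys(key), state.sys
    )
    vel = state.vel + dt * force / new_sys.mass
    return State(vel=vel, sys=new_sys, rng=rng, t=t)


# vmap batches every field of State, sys included, along axis 0.
batched_reset = jax.vmap(reset)
batched_step = jax.jit(jax.vmap(env_step, in_axes=(0, None)))

states = batched_reset(jax.random.split(jax.random.PRNGKey(0), 8))
initial_mass = states.sys.mass
for _ in range(49):
    states = batched_step(states, 1.0)
mass_after_49 = states.sys.mass     # unchanged: no multiple of 50 reached yet
states = batched_step(states, 1.0)  # t == 50: every env resamples its sys
```

Since sys is just another State field, no special batching logic is needed; the cost of the jnp.where design is that a fresh sample is drawn on every step and discarded when unused, which is cheap here.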
The summary of the implementation is simple: we just augment the simulation state to contain sys, thereby giving every parallel environment access to its own separate sys. This also lets us resample sys according to some rule (for example, "resample every 50 steps").

Features:
The vsys wrapper allows for a vectorized sys variable that may contain different domain randomization values for each vectorized env.

Domain randomization is controlled via a simple YAML file format that describes the path to a domain randomization target. Example:
This randomizes over the 7 links of the robot. For the mass, the base is "r", so the value is "read" from the default value defined in the URDF file. The min and max are both relative to the base, so the current setup randomizes over [r-0.5, r+0.5]. For the damping, no base is given, so it defaults to "r". The base can also be set to a float value. Another possible value for the base is "n" ("none"), which disables randomization for this index.
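The base semantics above can be sketched as follows. This is a hypothetical re-implementation, not the PR's YAML parser: it assumes each entry is a (base, min, max) triple in which min and max are offsets added to the base (so min=-0.5, max=0.5 gives [r-0.5, r+0.5]), and that defaults holds the per-link values read from the URDF. The default masses below are illustrative only.

```python
from typing import List, Optional, Sequence, Tuple, Union

import jax
import jax.numpy as jnp

Base = Union[str, float, None]
Entry = Tuple[Base, float, float]  # hypothetical (base, min, max) triple


def resolve_range(base: Base, lo: float, hi: float, default: float) -> Optional[Tuple[float, float]]:
    """Turn one (base, min, max) entry into an absolute [low, high] range.

    Assumed semantics, following the description above:
      None (no base) or "r": use the default value read from the URDF
      "n": no randomization for this index (returns None)
      a float: use that float as the base
    min and max are offsets relative to the base.
    """
    if base is None or base == "r":
        center = default
    elif base == "n":
        return None
    else:
        center = float(base)
    return center + lo, center + hi


def sample_links(rng: jnp.ndarray, defaults: Sequence[float], entries: Sequence[Entry]) -> jnp.ndarray:
    # One independent uniform draw per link; "n" entries keep their default.
    keys = jax.random.split(rng, len(defaults))
    values: List[jnp.ndarray] = []
    for key, default, (base, lo, hi) in zip(keys, defaults, entries):
        bounds = resolve_range(base, lo, hi, default)
        if bounds is None:
            values.append(jnp.asarray(default, dtype=jnp.float32))
        else:
            values.append(jax.random.uniform(key, (), minval=bounds[0], maxval=bounds[1]))
    return jnp.stack(values)


# Example: 7 links with illustrative default masses; the last link is not randomized.
defaults = [1.0, 2.0, 1.5, 1.0, 0.8, 0.6, 0.4]
entries: List[Entry] = [("r", -0.5, 0.5)] * 6 + [("n", -0.5, 0.5)]
masses = sample_links(jax.random.PRNGKey(0), defaults, entries)
```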
Domain randomization is differentiable (!)
For example, by running a simple optax optimizer, we can recover the true domain randomization parameters in play at a specific timestep.
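A toy illustration of why differentiability makes this possible, assuming a unit point mass with linear damping instead of a brax system, and plain gradient descent in place of the optax optimizer mentioned above. jax.grad differentiates the loss through the whole rollout (a jax.lax.scan), so an unknown damping value can be fitted to an observed trajectory. DT, N_STEPS, FORCE and the initial guess are hypothetical choices.

```python
import jax
import jax.numpy as jnp

DT = 0.1       # hypothetical integration step
N_STEPS = 20   # hypothetical rollout length
FORCE = 1.0    # hypothetical constant force on a unit mass


def rollout(damping: jnp.ndarray) -> jnp.ndarray:
    """Velocities of a unit point mass under dv/dt = f - c * v (explicit Euler)."""
    def body(vel, _):
        vel = vel + DT * (FORCE - damping * vel)
        return vel, vel
    _, vels = jax.lax.scan(body, jnp.float32(0.0), None, length=N_STEPS)
    return vels


def loss(damping: jnp.ndarray, observed: jnp.ndarray) -> jnp.ndarray:
    # Mean squared error between simulated and observed velocities.
    return jnp.mean((rollout(damping) - observed) ** 2)


true_damping = jnp.float32(0.5)   # the "unknown" randomized parameter
observed = rollout(true_damping)  # stands in for a recorded trajectory

grad_fn = jax.jit(jax.grad(loss))
damping = jnp.float32(1.5)        # initial guess
learning_rate = 1.0
for _ in range(200):
    # Plain gradient descent; an optax optimizer would replace this update.
    damping = damping - learning_rate * grad_fn(damping, observed)
# damping should now be close to true_damping
```

In this toy setting the simulated velocities decrease monotonically with damping, so the gradient always points toward the true value; a real multi-parameter brax system offers no such guarantee.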
Known issues:
- Because we need sys to be included in the state, a few of the Python type hints are broken.
- … if __name__ == "__main__": function. Specifically, here: https://github.com/Velythyl/brax/blob/b6cab6449ba677108e37739286e0521f7c226a9e/brax/envs/wrappers/vsys.py#L553

Again, I don't expect this to be merged as-is. But perhaps the implementation might be interesting to the community, hence this PR.