google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Is Brax suitable for sim2real RL and complex terrain? #229

Closed: imoneoi closed this issue 1 year ago

imoneoi commented 1 year ago

Is Brax suitable for sim2real RL training, and are there any examples?

Also, can we use complex terrain in simulation? There is a heightmap geometry type; however, only heightmap-box collision is implemented. Does that mean we can only use boxes to compose the agent, with a heightmap for the terrain?

btaba commented 1 year ago

Hi @imoneoi. You are correct: capsule-heightmap is not implemented (feel free to send over an impl), so there is no complex terrain aside from capsule-mesh, capsule-box, and capsule-plane collisions, assuming your end-effector is a capsule. But randomizing over other params should still get some form of sim2real working. Check out the domain randomization wrapper pushed to experimental: https://github.com/google/brax/blob/main/brax/experimental/tracing/wrappers.py. @cdfreeman-google can provide insight if you have questions about the domain randomization.

imoneoi commented 1 year ago

Thanks! I have some additional questions about sim2real and domain randomization:

btaba commented 1 year ago
  1. Box-heightmap collisions can be used, but as currently implemented, box-heightmap only takes box-corner collisions into account, so you may get some idiosyncrasies in sim that won't transfer well to real. We don't currently have a great way to introduce surface irregularities to randomize over in sim. Your best bet is likely a capsule end-effector with plane or mesh collisions.
  2. The wrapper should vectorize the config (https://github.com/google/brax/blob/main/brax/experimental/tracing/wrappers_test.py#L32-L44) given vectorized params like friction. We are actively working on supporting sim2real so there may be better examples soon.
  3. We can't guarantee that randomizing over those parameters will work for your domain. That's sort of the failure mode of domain randomization, though: the parameters to randomize over are domain dependent. Those are definitely good parameters to start off with, and a subset [CoM, mass, friction] might work too!
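
For a feel of what randomizing over the subset [CoM, mass, friction] involves, here is a plain-JAX sketch (not the Brax wrapper API) that samples one parameter set per environment by vmapping over a batch of PRNG keys. The nominal values, the ranges, and the `sample_params` helper are hypothetical placeholders.

```python
import jax
import jax.numpy as jnp

# Hypothetical nominal values and ranges; tune these for your own robot.
NOMINAL_MASS = 1.0       # kg
NOMINAL_FRICTION = 0.8
COM_OFFSET_RANGE = 0.02  # metres, per axis


def sample_params(key):
  """Samples one environment's randomized [CoM, mass, friction] (hypothetical helper)."""
  k_com, k_mass, k_fric = jax.random.split(key, 3)
  com_offset = jax.random.uniform(
      k_com, (3,), minval=-COM_OFFSET_RANGE, maxval=COM_OFFSET_RANGE)
  mass = NOMINAL_MASS * jax.random.uniform(k_mass, minval=0.8, maxval=1.2)
  friction = NOMINAL_FRICTION * jax.random.uniform(k_fric, minval=0.5, maxval=1.5)
  return {'com_offset': com_offset, 'mass': mass, 'friction': friction}


num_envs = 4
keys = jax.random.split(jax.random.PRNGKey(0), num_envs)
# vmap over the keys gives a batched pytree: every leaf has a leading env axis.
params = jax.vmap(sample_params)(keys)
print(params['friction'].shape)    # (4,)
print(params['com_offset'].shape)  # (4, 3)
```

Passing fresh keys on every reset would give fresh parameters, rather than fixing them once at initialization.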

cdfreeman-google commented 1 year ago

One extra comment on point 2: yes, the current domain randomization code is written this way, with the randomness fixed at environment initialization time. It's fairly straightforward to extend it so that the randomness is part of the environment seed (i.e., so that the environment randomness is refreshed after every reset). You would need to modify the post_process_fn in https://github.com/google/brax/blob/main/brax/experimental/tracing/wrappers.py so that it properly keeps track of the randomness, likely with another auxiliary function for generating fresh custom_trees. For the reset function, I'd imagine it going approximately like this (note: not tested):

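```python
  # Untested sketch: methods for the domain-randomization wrapper in
  # brax/experimental/tracing/wrappers.py so randomness is refreshed on reset.

  def custom_tree_generator(self, rng):
    # ...some logic to generate a custom tree from randomness...
    return randomized_custom_tree

  def reset(self, rng):
    # rng: a batch of PRNG keys, one per environment (vmapped along axis 0).

    def reset_fn(custom_tree):
      # Split each env's key so the randomized params and the initial
      # state do not reuse the same randomness.
      param_rng, reset_rng = jax.random.split(custom_tree['rng'])

      def post_process_fn(config):
        # Rebuild the traced config from a freshly sampled custom tree.
        config = TracedConfig(config, custom_tree=self.custom_tree_generator(param_rng))
        return config

      env = self.env_fn(post_process_fn=post_process_fn)
      return env.reset(reset_rng)

    # Map over the 'rng' entry so each env gets its own key.
    self.custom_tree_in_axes[0]['rng'] = 0
    self.custom_tree['rng'] = rng

    return jax.vmap(
        reset_fn, in_axes=self.custom_tree_in_axes)(
            self.custom_tree)
```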

imoneoi commented 1 year ago

Thanks, @btaba! From the code, checking collision between a capsule and a mesh of N triangles looks like O(N) work, whereas box-heightmap is O(1). Will it be very slow to represent terrain as a mesh?

I'm considering writing a complete O(1) box-heightmap or capsule-heightmap impl. Also, the viewer may need an update, since it does not support meshes or heightmaps.

BTW, I think structures like bounding volume hierarchies (BVHs) are very hard to implement in JAX. Does that mean we can currently only have fast collision detection on heightmaps, not meshes?

btaba commented 1 year ago

@imoneoi yes, capsule-mesh is slow for roughly the reasons you describe, which is why we'd want capsule-heightmap to check only a subset of the heightmap rather than the whole mesh. I think capsule-heightmap is a higher priority than box-heightmap, but feel free to fix box-heightmap as well! What both likely need is a jittable collision check between a line segment and a subset of the heightmap's triangles.
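
One simple way to get a jittable, fixed-cost capsule-heightmap check is to sample a fixed number of points along the capsule's axis and compare each against the bilinearly interpolated terrain height. This is a swapped-in approximation, not Brax code and not an exact segment-triangle test: it measures vertical rather than surface-normal distance. The function names and the grid layout (`heights[i, j]` sits at `(i * cell_size, j * cell_size)`) are assumptions.

```python
import jax
import jax.numpy as jnp


def heightmap_height(heights, cell_size, xy):
  """Bilinearly interpolated terrain height at world (x, y).

  Assumes heights[i, j] is the height at (i * cell_size, j * cell_size).
  """
  n_x, n_y = heights.shape
  g = xy / cell_size
  # Clamp so the 2x2 cell lookup stays inside the grid.
  i0 = jnp.clip(jnp.floor(g[0]).astype(jnp.int32), 0, n_x - 2)
  j0 = jnp.clip(jnp.floor(g[1]).astype(jnp.int32), 0, n_y - 2)
  tx = jnp.clip(g[0] - i0, 0.0, 1.0)
  ty = jnp.clip(g[1] - j0, 0.0, 1.0)
  h00 = heights[i0, j0]
  h10 = heights[i0 + 1, j0]
  h01 = heights[i0, j0 + 1]
  h11 = heights[i0 + 1, j0 + 1]
  return ((1 - tx) * (1 - ty) * h00 + tx * (1 - ty) * h10
          + (1 - tx) * ty * h01 + tx * ty * h11)


def capsule_heightmap_penetration(p0, p1, radius, heights, cell_size, num_samples=8):
  """Approximate max penetration of a capsule (segment p0-p1, radius) into terrain.

  Cost is O(num_samples), independent of heightmap size. Uses vertical
  distance, not surface-normal distance, so it is only a coarse check.
  """
  t = jnp.linspace(0.0, 1.0, num_samples)
  pts = p0[None, :] + t[:, None] * (p1 - p0)[None, :]  # (num_samples, 3)
  ground = jax.vmap(lambda p: heightmap_height(heights, cell_size, p[:2]))(pts)
  # Lowest point of the capsule at each sample is z - radius; > 0 means contact.
  penetration = ground + radius - pts[:, 2]
  return jnp.max(penetration)


# num_samples sets an array shape, so it must be static under jit.
penetration_fn = jax.jit(capsule_heightmap_penetration, static_argnames=('num_samples',))
```

Because the sample count is fixed, the output shape never depends on where the capsule is, which keeps it compatible with `jax.vmap` across envs.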

RE the viewer: what issue are you seeing with meshes and heightmaps in the viz? They are supported here: https://github.com/google/brax/blob/main/js/system.js#L251-L257

RE BVH: branching logic is difficult to implement in JAX. We started looking into mesh-mesh collisions and how to make them fast in JAX, but don't have a fleshed-out solution yet. We recently added JAX-jittable box-box collisions.
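
Since tree traversal is awkward in JAX, a common JAX-friendly substitute (swapped in here; not something Brax provides) is a brute-force broad phase: test every triangle's axis-aligned bounding box against a query box in parallel, then compact the hits into a fixed-size index array with `jnp.nonzero(..., size=...)`. `candidate_triangles` and `max_candidates` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def candidate_triangles(tris, box_min, box_max, max_candidates):
  """Indices of triangles whose AABBs overlap a query AABB, padded to a fixed size.

  tris: (N, 3, 3) triangle vertices. Every triangle is tested (no tree
  traversal, no branching), and nonzero(size=...) keeps the output shape
  static so it works under jit/vmap. Unused slots are filled with -1.
  """
  tri_min = tris.min(axis=1)  # (N, 3)
  tri_max = tris.max(axis=1)  # (N, 3)
  overlap = jnp.all((tri_min <= box_max) & (tri_max >= box_min), axis=-1)  # (N,)
  (idx,) = jnp.nonzero(overlap, size=max_candidates, fill_value=-1)
  return idx, overlap.sum()


# max_candidates sets the output shape, so it must be static under jit.
candidate_triangles_jit = jax.jit(candidate_triangles, static_argnames=('max_candidates',))
```

The work is O(N) but fully parallel and branch-free. If more than `max_candidates` triangles overlap, the extras are silently dropped, so the returned count lets the caller detect overflow.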

imoneoi commented 1 year ago

@btaba I'm working on a fast jittable capsule-heightmap or box-heightmap collision impl. I think we should identify the heightmap points local to the body and test collision only against those. But JAX's fixed-shape requirement may make finding the local points annoying.

Also, can we avoid vmapping the heightmap data when doing domain randomization? (10240 envs × 1024 × 1024 heights really consumes a lot of VRAM.)

For the viewer: sorry, I missed that line of code and thought meshes and heightmaps weren't supported.

btaba commented 1 year ago

Hey @imoneoi, sorry for the late reply! You could try creating a bounding box of worst-case static size around the capsule, and check the heightmap against that fixed bounding box at each step.
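
A minimal sketch of the static-size bounding box idea, assuming a hypothetical `WINDOW` constant chosen from the capsule's worst-case extent: `jax.lax.dynamic_slice` accepts a traced start index but a static slice size, so the local patch always has the same shape. The capsule is then tested only against the heightmap vertices in that patch. This ignores triangle interiors, so it is a coarse check rather than a full capsule-heightmap collision; all names here are illustrative.

```python
import jax
import jax.numpy as jnp

WINDOW = 8  # hypothetical static window size in cells; must cover the capsule's worst case


def local_heightmap_window(heights, cell_size, center_xy):
  """Returns a WINDOW x WINDOW block of heights around center_xy and its start indices.

  Assumes heights[i, j] is the height at (i * cell_size, j * cell_size).
  """
  n_x, n_y = heights.shape
  ij = jnp.floor(center_xy / cell_size).astype(jnp.int32) - WINDOW // 2
  i0 = jnp.clip(ij[0], 0, n_x - WINDOW)
  j0 = jnp.clip(ij[1], 0, n_y - WINDOW)
  window = jax.lax.dynamic_slice(heights, (i0, j0), (WINDOW, WINDOW))
  return window, i0, j0


def capsule_window_penetration(p0, p1, radius, heights, cell_size):
  """Capsule vs. heightmap vertices inside the local window (vertices only)."""
  center = 0.5 * (p0 + p1)
  window, i0, j0 = local_heightmap_window(heights, cell_size, center[:2])
  ii, jj = jnp.meshgrid(jnp.arange(WINDOW), jnp.arange(WINDOW), indexing='ij')
  verts = jnp.stack([(i0 + ii) * cell_size, (j0 + jj) * cell_size, window], axis=-1)
  verts = verts.reshape(-1, 3)
  # Closest point on the segment p0-p1 to each vertex.
  d = p1 - p0
  t = jnp.clip(((verts - p0) @ d) / jnp.maximum(d @ d, 1e-9), 0.0, 1.0)
  closest = p0 + t[:, None] * d
  dist = jnp.linalg.norm(verts - closest, axis=-1)
  return radius - jnp.min(dist)  # > 0 means some vertex is inside the capsule
```

The cost per step is fixed at WINDOW × WINDOW vertices regardless of the heightmap size, at the price of choosing WINDOW conservatively up front.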

If you're randomizing the heightmap (i.e., more than one heightmap), I'm not sure how to avoid storing them all, unless the heights are generated on the fly by some function, say f(x, y, key) = jax.random.uniform(key+x+y), so that each env gets a different key and its heights are generated on demand.
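
The implicit-heightmap idea can be sketched with `jax.random.fold_in`, which mixes integer data into a key; it stands in here for the `key+x+y` shorthand, since `fold_in` is JAX's standard way to derive a new key from integers. Heights are generated only for the cells in a window, so nothing of size 1024 × 1024 is stored per env. `implicit_height`, `implicit_window`, and `MAX_HEIGHT` are hypothetical, and uniform white-noise terrain is only a placeholder for a smoother generator.

```python
import jax
import jax.numpy as jnp

MAX_HEIGHT = 0.1  # metres; hypothetical terrain roughness


def implicit_height(key, ix, iy):
  """Height of grid cell (ix, iy), derived deterministically from the env's key.

  Nothing is stored: the same (key, ix, iy) always yields the same height.
  """
  k = jax.random.fold_in(jax.random.fold_in(key, ix), iy)
  return MAX_HEIGHT * jax.random.uniform(k)


def implicit_window(key, i0, j0, window=8):
  """Generates only a window x window patch of one env's terrain, on demand."""
  ii, jj = jnp.meshgrid(jnp.arange(window) + i0, jnp.arange(window) + j0, indexing='ij')
  # Nested vmap over rows, then cells; the key is shared (not mapped).
  f = jax.vmap(jax.vmap(implicit_height, in_axes=(None, 0, 0)), in_axes=(None, 0, 0))
  return f(key, ii, jj)
```

Because each height depends only on `(key, ix, iy)`, overlapping windows agree, and `jax.vmap(implicit_window, in_axes=(0, None, None))(keys, i0, j0)` gives every env its own terrain from its own key.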

imoneoi commented 1 year ago

Using a static size is a good idea. I may also try something like jax.ops.segment_sum to reduce the collision cost. For the heightmap, I think dynamic generation may still take up 10240 envs × 1024 × 1024 worth of VRAM, since all the envs generate their heights at the same time. Maybe we can use vmap() without vmapping some of the arguments.
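
Sharing one heightmap across environments is what `in_axes=None` in `jax.vmap` expresses: the argument is left unbatched, so there is a single (n, n) array instead of one copy per env, and only the per-env start positions carry a batch axis. This sketch uses a nearest-cell lookup and a hypothetical `terrain_height` helper; the sizes are illustrative.

```python
import jax
import jax.numpy as jnp


def terrain_height(heights, cell_size, xy):
  """Nearest-cell terrain height at world (x, y); heights[i, j] sits at (i, j) * cell_size."""
  ij = jnp.round(xy / cell_size).astype(jnp.int32)
  i = jnp.clip(ij[0], 0, heights.shape[0] - 1)
  j = jnp.clip(ij[1], 0, heights.shape[1] - 1)
  return heights[i, j]


num_envs, n = 1024, 256
cell_size = 0.05
heights = jax.random.uniform(jax.random.PRNGKey(0), (n, n)) * 0.1  # one shared map
# Per-env randomized start positions over the shared terrain.
starts = jax.random.uniform(jax.random.PRNGKey(1), (num_envs, 2), maxval=(n - 1) * cell_size)

# in_axes=None: heights and cell_size are not batched; only starts is mapped.
spawn_z = jax.vmap(terrain_height, in_axes=(None, None, 0))(heights, cell_size, starts)
print(spawn_z.shape)  # (1024,)
```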

btaba commented 1 year ago

Yeah, using the same heightmap and randomizing start positions may get rid of that memory blow-up and likely give you similar domain-randomized behavior. But I'm not really in the weeds here, so I'll defer to you; a working capsule-heightmap collision func would already be a great win here!

With the implicit function approach, memory use would scale with the worst-case bounding box projected onto the heightmap (depending on the impl and the size of things), so it's likely smaller than 1024x1024.
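
For scale, assuming float32 heights (4 bytes each) and a hypothetical 32 × 32 worst-case window, storing a full 1024 × 1024 map per env versus only a window per env for 10240 envs works out to:

```latex
\begin{aligned}
\text{full maps: } & 10240 \times 1024 \times 1024 \times 4\,\text{B} = 42{,}949{,}672{,}960\,\text{B} = 40\,\text{GiB} \\
32 \times 32 \text{ windows: } & 10240 \times 32 \times 32 \times 4\,\text{B} = 41{,}943{,}040\,\text{B} = 40\,\text{MiB}
\end{aligned}
```

That is a factor of 1024 between the two, with the actual window size depending on the capsule's worst-case extent and the heightmap resolution.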

Thanks for taking a look at this!