Hi @imoneoi. You are correct: capsule-heightmap collision is not implemented (feel free to send over an impl), so, assuming your end-effector is a capsule, complex terrain is limited to capsule-mesh, capsule-box, and capsule-plane. But randomizing over other params should get some form of sim2real working. Check out the domain randomization wrapper pushed to experimental, https://github.com/google/brax/blob/main/brax/experimental/tracing/wrappers.py; @cdfreeman-google can provide insight if you have questions about the domain randomization.
Thanks! I have some additional questions about sim2real and domain randomization:
One extra comment on point number 2: yes, the current domain randomization code is written this way, with the randomness fixed at environment initialization time. It's fairly straightforward to extend this so that the randomness is part of the environment seed (i.e., so that environment randomness is refreshed after every reset). You would simply need to modify the post_process_fn in https://github.com/google/brax/blob/main/brax/experimental/tracing/wrappers.py so that it properly book-keeps the randomness, likely with another auxiliary function for generating fresh custom_trees, e.g., something like the following (note: not tested, but this is approximately how I would imagine it going for the reset function):
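import jax

# Untested sketch. TracedConfig, self.env_fn, self.custom_tree and
# self.custom_tree_in_axes are assumed to come from the existing tracing wrapper.

def custom_tree_generator(self, rng):
  # ...some logic to generate a custom tree from randomness...
  return randomized_custom_tree

def reset(self, rng):
  def reset_fn(custom_tree):
    def post_process_fn(config):
      # Draw a fresh randomized custom tree from this env's key on every reset.
      config = TracedConfig(config, custom_tree=self.custom_tree_generator(custom_tree['rng']))
      return config
    env = self.env_fn(post_process_fn=post_process_fn)
    return env.reset(custom_tree['rng'])
  # Batch the rng leaf along axis 0 so each env gets its own key; rng should
  # hold one key per env, e.g. jax.random.split(key, num_envs).
  self.custom_tree_in_axes[0]['rng'] = 0
  self.custom_tree['rng'] = rng
  return jax.vmap(
      reset_fn, in_axes=self.custom_tree_in_axes)(
          self.custom_tree)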
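To make the "refresh on every reset" idea concrete without depending on Brax internals, here is a minimal plain-JAX sketch. The parameter names (`friction`, `mass_scale`), their ranges, and the random `qpos` standing in for `env.reset` are all hypothetical; the point is only that each call splits the incoming key into one key per env, so parameters are resampled at every reset instead of being fixed at construction.

```python
import functools

import jax
import jax.numpy as jnp


def sample_params(rng):
  # Hypothetical domain-randomization parameters; ranges are illustrative.
  k_friction, k_mass = jax.random.split(rng)
  friction = jax.random.uniform(k_friction, minval=0.5, maxval=1.5)
  mass_scale = jax.random.uniform(k_mass, minval=0.8, maxval=1.2)
  return {'friction': friction, 'mass_scale': mass_scale}


def reset_one(rng):
  # Separate keys for the parameter draw and the initial-state draw.
  param_key, state_key = jax.random.split(rng)
  params = sample_params(param_key)
  qpos = jax.random.normal(state_key, (3,))  # stand-in for env.reset
  return params, qpos


@functools.partial(jax.jit, static_argnums=1)
def batched_reset(rng, num_envs):
  # One key per env, derived from the reset key: new parameters every reset.
  rngs = jax.random.split(rng, num_envs)
  return jax.vmap(reset_one)(rngs)
```

Calling `batched_reset(jax.random.PRNGKey(0), 4)` returns a dict of `(4,)` parameter arrays plus `(4, 3)` initial positions; a different reset key yields different parameters.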
Thanks! @btaba, from what I see in the code, checking collision between a capsule and a mesh of N triangles is O(N), whereas box-heightmap is O(1). Will it be very slow to represent terrain as a mesh?
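To illustrate why a heightmap query is O(1) while the simplest capsule-mesh test visits every one of the N triangles, here is a minimal sketch (not Brax's implementation) of a point-vs-heightmap query. The grid layout (`heights[i, j]` sampled every `cell_size`, at least 2×2) and the function names are assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def heightmap_height(heights, cell_size, x, y):
  """Bilinearly interpolated terrain height at world position (x, y).

  Assumes heights[i, j] is the height at (i * cell_size, j * cell_size)
  and the grid is at least 2x2. Positions outside are clamped to the edge.
  """
  n_rows, n_cols = heights.shape
  gx = jnp.clip(x / cell_size, 0.0, n_rows - 1)
  gy = jnp.clip(y / cell_size, 0.0, n_cols - 1)
  # Lower corner of the enclosing cell, kept one row/col away from the edge.
  i = jnp.clip(jnp.floor(gx).astype(jnp.int32), 0, n_rows - 2)
  j = jnp.clip(jnp.floor(gy).astype(jnp.int32), 0, n_cols - 2)
  tx, ty = gx - i, gy - j
  # Exactly four reads, independent of the grid size: O(1) per query.
  h00, h10 = heights[i, j], heights[i + 1, j]
  h01, h11 = heights[i, j + 1], heights[i + 1, j + 1]
  return ((1 - tx) * (1 - ty) * h00 + tx * (1 - ty) * h10
          + (1 - tx) * ty * h01 + tx * ty * h11)


@jax.jit
def point_penetration(heights, cell_size, point):
  """Positive when `point` (x, y, z) lies below the terrain surface."""
  return heightmap_height(heights, cell_size, point[0], point[1]) - point[2]
```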
I'm considering writing an O(1) impl of complete box-heightmap or capsule-heightmap collision. Also, the viewer may need an update: it does not support mesh and heightmap.
BTW, I think things like BVH are very hard to implement in JAX. Maybe currently we can only have fast collision detection on heightmaps, not meshes?
@imoneoi yes for roughly the reasons you describe capsule-mesh is slow, which is why we'd want capsule-heightmap to check a subset of the heightmap rather than the whole mesh. I think capsule-heightmap would be of higher priority compared to box-heightmap, but feel free to fix box-heightmap as well! What's likely needed in both is a jittable collision check between a line segment and a subset of triangles in the heightmap.
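A crude but jittable stand-in, assuming a hypothetical `heights[i, j]` grid sampled every `cell_size`, is to sample a fixed number of points along the capsule's segment, treat each as a sphere of the capsule radius, and measure depth vertically against the nearest heightmap cell. This is an approximation swapped in for illustration, not the exact segment-vs-triangle check described above: it ignores terrain slope and can miss features narrower than the sample spacing.

```python
import jax
import jax.numpy as jnp

NUM_SAMPLES = 8  # static, so array shapes stay fixed under jit


def _nearest_height(heights, cell_size, x, y):
  # Nearest-cell lookup (coarser than bilinear interpolation), clamped to grid.
  i = jnp.clip(jnp.round(x / cell_size).astype(jnp.int32), 0, heights.shape[0] - 1)
  j = jnp.clip(jnp.round(y / cell_size).astype(jnp.int32), 0, heights.shape[1] - 1)
  return heights[i, j]


@jax.jit
def capsule_heightmap_penetration(heights, cell_size, p0, p1, radius):
  """Approximate max penetration of capsule (p0, p1, radius) into the terrain.

  Positive means contact. Depth is measured vertically at NUM_SAMPLES
  points evenly spaced on the segment, including both endpoints.
  """
  t = jnp.linspace(0.0, 1.0, NUM_SAMPLES)
  pts = p0[None, :] + t[:, None] * (p1 - p0)[None, :]
  ground = jax.vmap(lambda p: _nearest_height(heights, cell_size, p[0], p[1]))(pts)
  depth = ground + radius - pts[:, 2]
  return jnp.max(depth)
```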
RE the viewer: what issue are you seeing with mesh and heightmap in the viz? These are supported here https://github.com/google/brax/blob/main/js/system.js#L251-L257
RE BVH: branching logic is difficult to implement in JAX. We started looking into mesh-mesh collisions and how to make those fast in JAX, but don't have a fleshed-out solution. We recently added JAX-jittable box-box collisions.
@btaba I'm working out a fast jittable capsule-heightmap or box-heightmap collision impl. I think we should identify the local heightmap points and test collision only against those, but JAX's fixed-shape requirement may make finding the local points awkward.
Also, can we avoid vmapping the heightmap data when using domain randomization? (10240 envs × 1024 × 1024 really consumes a lot of VRAM.)
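For scale, assuming each height is stored as a 4-byte float32 and every env holds its own 1024×1024 copy, the heightmaps alone take:

```latex
10240 \times 1024 \times 1024 \times 4\,\text{B}
  = 10240 \times 4\,\text{MiB}
  = 40960\,\text{MiB}
  = 40\,\text{GiB}
```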
For the viewer: sorry, I missed that line of code and thought it wasn't supported before.
Hey @imoneoi, sorry for the late reply! You could try creating a bounding box for the capsule with a worst-case static size and checking the heightmap against that fixed bounding box at each step.
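A minimal sketch of the fixed-size bounding-box idea, assuming a hypothetical `WINDOW` (the worst-case capsule footprint in cells, chosen ahead of time) and a made-up `heights[i, j]` grid layout: `jax.lax.dynamic_slice` accepts a traced start index but needs a static slice size, which is exactly the "fixed shape, moving position" combination that jit requires.

```python
import jax
import jax.numpy as jnp

WINDOW = 8  # hypothetical worst-case footprint in cells; must be static


@jax.jit
def local_window(heights, cell_size, center_xy):
  """Return a WINDOW x WINDOW patch of `heights` around center_xy.

  The patch shape never changes, so this stays jittable even though the
  capsule position is traced. Start indices are clamped to the grid and
  returned so patch cells can be mapped back to world coordinates.
  """
  ci = jnp.floor(center_xy[0] / cell_size).astype(jnp.int32)
  cj = jnp.floor(center_xy[1] / cell_size).astype(jnp.int32)
  start_i = jnp.clip(ci - WINDOW // 2, 0, heights.shape[0] - WINDOW)
  start_j = jnp.clip(cj - WINDOW // 2, 0, heights.shape[1] - WINDOW)
  patch = jax.lax.dynamic_slice(heights, (start_i, start_j), (WINDOW, WINDOW))
  return patch, start_i, start_j
```

Only WINDOW² heights per env are materialized; with the hypothetical WINDOW = 8 and float32, that is 10240 × 64 × 4 B = 2.5 MiB across 10240 envs.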
If you're randomizing the heightmap (so not just one heightmap), I'm not sure how to avoid storing those, unless the heightmap heights are generated on the fly by some function, say f(x, y, key) = jax.random.uniform(key+x+y), so that each env gets a different key and the height is generated on the fly.
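One way to write such an on-the-fly function as valid JAX is to derive each cell's key with `jax.random.fold_in` rather than adding integers to a key. This is a minimal sketch: `implicit_height`, `patch_heights`, `PATCH`, the uniform height distribution and `max_height` are hypothetical choices, and real terrain would likely want smoother noise than independent per-cell values.

```python
import jax
import jax.numpy as jnp

PATCH = 4  # hypothetical patch size in cells; static for jit


def implicit_height(key, i, j, max_height=0.1):
  """Height of grid cell (i, j), recomputed from the env's key on demand.

  Nothing is stored: the same (key, i, j) always yields the same height.
  """
  cell_key = jax.random.fold_in(jax.random.fold_in(key, i), j)
  return max_height * jax.random.uniform(cell_key)


@jax.jit
def patch_heights(key, start_i, start_j):
  """PATCH x PATCH heights for cells starting at (start_i, start_j)."""
  ii, jj = jnp.meshgrid(start_i + jnp.arange(PATCH), start_j + jnp.arange(PATCH),
                        indexing='ij')
  # Inner vmap over columns, outer vmap over rows; the key is shared.
  per_cell = jax.vmap(jax.vmap(implicit_height, in_axes=(None, 0, 0)),
                      in_axes=(None, 0, 0))
  return per_cell(key, ii, jj)
```

Overlapping patches agree cell by cell, so memory only scales with the cells actually queried.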
Using a static size is a good idea. I may also try something like jax.ops.segment_sum to reduce collision cost. For the heightmap, I think dynamic generation may also take up 10240 envs × 1024 × 1024 of VRAM, since it all happens at the same time. Maybe we can use vmap() without vmapping some arguments.
Yeah, using the same heightmap and randomizing start positions may get rid of that memory blow-up and likely give you similar domain-randomized behavior. But I'm not really in the weeds here, so I'm deferring to you; a working capsule-heightmap collision func would already be a great win here!
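Sharing one heightmap across envs corresponds to passing `in_axes=None` for it in `jax.vmap`, so only the per-env arguments get a batch dimension. A minimal sketch, with a hypothetical `init_env` that randomizes only the start position and spawns the body a made-up 0.5 above the terrain:

```python
import jax
import jax.numpy as jnp


def terrain_height_at(heights, cell_size, xy):
  # Nearest-cell lookup, clamped to the grid.
  i = jnp.clip(jnp.round(xy[0] / cell_size).astype(jnp.int32), 0, heights.shape[0] - 1)
  j = jnp.clip(jnp.round(xy[1] / cell_size).astype(jnp.int32), 0, heights.shape[1] - 1)
  return heights[i, j]


def init_env(heights, cell_size, key):
  # Randomize only the start position; the terrain is the same for every env.
  extent = jnp.array([heights.shape[0] - 1, heights.shape[1] - 1]) * cell_size
  xy = jax.random.uniform(key, (2,), minval=0.0, maxval=extent)
  z = terrain_height_at(heights, cell_size, xy) + 0.5  # hypothetical spawn offset
  return jnp.concatenate([xy, z[None]])


# in_axes=None for heights and cell_size: vmap adds no batch dimension to
# them, so there is no per-env copy of the map; only the keys are batched.
batched_init = jax.jit(jax.vmap(init_env, in_axes=(None, None, 0)))
```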
With the implicit function approach, memory use would be bounded by the worst-case bounding box projected onto the heightmap (depending on the impl and the size of things), so likely smaller than 1024x1024.
Thanks for taking a look at this!
Is Brax suitable for sim2real RL training, and are there any examples?
Also, can we use complex terrain in simulation? There is a heightmap among the geometries; however, only heightmap-box collision is implemented. Does that mean we can only compose the agent from boxes and use a heightmap for the terrain?