google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Many randomized environments in parallel #338

Closed jonasrothfuss closed 1 year ago

jonasrothfuss commented 1 year ago

I would like to use some form of domain randomization with brax environments. Ideally, I would like to re-sample the parameters of the envs in every iteration / after every step. So far I have only managed to sample random parameters and replace the default parameters in the environment XML. However, loading the 'system' from XML takes a long time, which makes this approach very inefficient.

Is there a clean and efficient way to change relevant physics parameters (e.g. body mass, size, friction coef, etc.) without the significant overhead of reconstructing the 'system' from scratch?

btaba commented 1 year ago

Hi @jonasrothfuss

Yes, there is a way to sample multiple systems, but we have not polished/released such a thing in v2. Tracing over the System is a high-level intent of the new design of brax. One way to get your project off the ground is to create a vectorized System like so:

import copy
import functools
from typing import Dict

import jax
from jax import numpy as jp

def _write_sys(sys, attr, val):
  """Replaces attributes in sys with val."""
  if not attr:
    return sys
  if len(attr) == 2 and attr[0] == 'geoms':
    geoms = copy.deepcopy(sys.geoms)
    if not hasattr(val, '__iter__'):
      for i, g in enumerate(geoms):
        if not hasattr(g, attr[1]):
          continue
        geoms[i] = g.replace(**{attr[1]: val})
    else:
      sizes = [g.transform.pos.shape[0] for g in geoms]
      g_idx = 0
      for i, g in enumerate(geoms):
        if not hasattr(g, attr[1]):
          continue
        size = sizes[i]
        geoms[i] = g.replace(**{attr[1]: val[g_idx:g_idx + size].T})
        g_idx += size
    return sys.replace(geoms=geoms)
  if len(attr) == 1:
    return sys.replace(**{attr[0]: val})
  return sys.replace(**{attr[0]:
                        _write_sys(getattr(sys, attr[0]), attr[1:], val)})

def set_sys(sys, params: Dict[str, jp.ndarray]):
  """Sets params in the System."""
  for k in params.keys():
    sys = _write_sys(sys, k.split('.'), params[k])
  return sys

def randomize(sys, rng):
  return set_sys(sys, {'link.inertia.mass': sys.link.inertia.mass + jax.random.uniform(rng, shape=(sys.num_links(),))})

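# Assumes `sys` is an already-loaded brax System and `batch_size` is defined.
rng = jax.random.PRNGKey(0)
rng, *key = jax.random.split(rng, batch_size + 1)
key = jp.reshape(jp.stack(key), (batch_size, 2))
# Bind sys positionally so that vmap maps over the rng argument only;
# partial(randomize, sys=sys)(key) would pass sys twice and raise a TypeError.
sys_v = jax.vmap(functools.partial(randomize, sys))(key)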

I have not fully tested the code, but the intent and rough idea should be there. [1] Create a batch of rng keys, [2] randomize the system over the batch of rng keys, [3] use the vectorized system sys_v to do stuff
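
To make steps [1] to [3] concrete without depending on brax, here is a minimal sketch in plain JAX. The dict `toy_sys`, the `step` function and the field names `mass` and `dt` are hypothetical stand-ins for a real System and its dynamics; only the key batching and the `jax.vmap` pattern carry over.

```python
import functools

import jax
from jax import numpy as jp

# Hypothetical stand-in for a brax System: a plain dict pytree.
toy_sys = {'mass': jp.array([1.0, 2.0, 3.0]), 'dt': jp.array(0.01)}


def randomize(sys, rng):
  """Adds uniform noise in [0, 1) to every mass."""
  noise = jax.random.uniform(rng, shape=sys['mass'].shape)
  return {**sys, 'mass': sys['mass'] + noise}


def step(sys, velocity, force):
  """Toy dynamics: one explicit Euler step of dv/dt = force / mass."""
  return velocity + sys['dt'] * force / sys['mass']


batch_size = 4
# [1] One PRNG key per environment.
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)
# [2] vmap the randomizer over the keys. Every leaf of the result gains a
# leading batch axis, including leaves that were not randomized.
sys_v = jax.vmap(functools.partial(randomize, toy_sys))(keys)
# [3] vmap the step over the batched system and the batched state; the
# force is shared across environments (in_axes=None).
velocity = jp.zeros((batch_size, 3))
force = jp.ones(3)
new_velocity = jax.vmap(step, in_axes=(0, 0, None))(sys_v, velocity, force)
```

Here `sys_v['mass']` has shape `(4, 3)` and `sys_v['dt']` has shape `(4,)`, so each of the four environments steps with its own masses.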

We're planning to get a wrapper that does this for you at some point, but still TBD

jonasrothfuss commented 1 year ago

Thanks a lot! The code snippet helped me a lot in getting vectorized random systems to work!

While the masses, friction and elasticity can be easily changed with the set_sys method, it seems that changing the length of links causes problems and changes the system dynamics in unintended ways. I think this is because the transforms (some form of rotation matrices?) depend on the length of the links and get calculated when loading the system XML. Are you aware of any clean workaround to change both the length and the transform of a link correctly within the system object without reloading the XML?

btaba commented 1 year ago

Hi @jonasrothfuss, sure, here's a code snippet for capsules which I've been using:

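import jax
from jax import numpy as jp
# brax.math provides normalize; this is not the stdlib math module.
from brax import base, math

def set_sys_capsules(sys, lengths, radii):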
  """Sets the system with new capsule lengths/radii."""
  sys2 = set_sys(sys, {'geoms.length': lengths})
  sys2 = set_sys(sys2, {'geoms.radius': radii})

  # we assume inertia.transform.pos is (0,0,0), as is often the case for
  # capsules

  # get the new joint transform
  cur_len = sys.geoms[1].length[:, None]
  joint_dir = jax.vmap(math.normalize)(sys.link.joint.pos)[0]
  joint_dist = sys.link.joint.pos - 0.5 * cur_len * joint_dir
  joint_transform = 0.5 * lengths[:, None] * joint_dir + joint_dist
  sys2 = set_sys(sys2, {'link.joint.pos': joint_transform})

  # get the new link transform
  parent_idx = jp.array([sys.link_parents])
  sys2 = set_sys(
      sys2,
      {
          'link.transform.pos': -(
              sys2.link.joint.pos
              + joint_dist
              + 0.5 * lengths[parent_idx].T * joint_dir
          )
      },
  )
  return sys2

@jax.jit
def randomize_sys_capsules(
    rng: jp.ndarray,
    sys: base.System,
    min_length: float = 0.0,
    max_length: float = 0.0,
    min_radius: float = 0.0,
    max_radius: float = 0.0,
):
  """Randomizes joint offsets, assume capsule geoms appear in geoms[1]."""
  rng, key1, key2 = jax.random.split(rng, 3)
  length_u = jax.random.uniform(
      key1, shape=(sys.num_links(),), minval=min_length, maxval=max_length
  )
  radius_u = jax.random.uniform(
      key2, shape=(sys.num_links(),), minval=min_radius, maxval=max_radius
  )
  length = length_u + sys.geoms[1].length  # pytype: disable=attribute-error
  radius = radius_u + sys.geoms[1].radius  # pytype: disable=attribute-error
  return set_sys_capsules(sys, length, radius)
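
The joint-offset arithmetic in set_sys_capsules can be checked in isolation. The sketch below reproduces only that step in plain JAX; the function name `rescale_joint_pos` and the sample numbers are made up for illustration, and it assumes every joint offset is non-zero so that its direction is defined.

```python
from jax import numpy as jp


def rescale_joint_pos(joint_pos, cur_len, new_len):
  """Shifts each joint offset along its own direction by half the length change.

  joint_pos: (n, 3) joint offsets; cur_len, new_len: (n,) capsule lengths.
  """
  norm = jp.linalg.norm(joint_pos, axis=-1, keepdims=True)
  joint_dir = joint_pos / norm  # unit direction of each offset
  # Part of the offset that is not accounted for by half the capsule length.
  joint_dist = joint_pos - 0.5 * cur_len[:, None] * joint_dir
  return 0.5 * new_len[:, None] * joint_dir + joint_dist


# Hypothetical values: the first capsule grows from 0.6 to 1.0, the second
# keeps its length.
joint_pos = jp.array([[0.5, 0.0, 0.0], [0.0, 0.0, -0.4]])
cur_len = jp.array([0.6, 0.4])
new_len = jp.array([1.0, 0.4])
new_pos = rescale_joint_pos(joint_pos, cur_len, new_len)
```

With these numbers the first joint moves from 0.5 to about 0.7 along x (half of the 0.4 increase in length), while the second joint stays where it was.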

jc-bao commented 1 year ago

Does anyone have any ideas on how to effectively implement domain randomization in the new v2 environment?

btaba commented 1 year ago

Hi @jc-bao, you can check out the test here for an example of domain randomization in v2:

https://github.com/google/brax/blob/9e14acb7fd2697236eac7ec8df6d94579dab8123/brax/training/agents/ppo/train_test.py#L100-L132

Notice that rand_fn randomizes the System, and the ppo.train routine takes advantage of this function to wrap the environment:

https://github.com/google/brax/blob/9e14acb7fd2697236eac7ec8df6d94579dab8123/brax/envs/wrappers/training.py#L53
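
As a rough illustration of how a randomization function and a wrapper can fit together, here is a sketch in plain JAX. It is not the code behind the links: `toy_sys`, `friction_force` and the convention of returning a batched pytree together with a matching `in_axes` pytree are assumptions made for this example, so check the linked test for the actual signature.

```python
import jax
from jax import numpy as jp

# Hypothetical stand-in for a brax System: a plain dict pytree.
toy_sys = {'mass': jp.array([1.0, 2.0]), 'friction': jp.array(0.8)}


def rand_fn(sys, rngs):
  """Batches only the friction leaf and reports which leaves are batched."""
  sample = lambda key: jax.random.uniform(key, minval=0.5, maxval=1.0)
  sys_v = {**sys, 'friction': jax.vmap(sample)(rngs)}
  # None marks a leaf shared by all environments, 0 a leaf batched on axis 0.
  in_axes = {k: None for k in sys}
  in_axes['friction'] = 0
  return sys_v, in_axes


def friction_force(sys, normal_force):
  """Toy per-link friction force; stands in for an environment step."""
  return sys['friction'] * normal_force * jp.ones_like(sys['mass'])


rngs = jax.random.split(jax.random.PRNGKey(0), 3)
sys_v, in_axes = rand_fn(toy_sys, rngs)
# A wrapper can pass in_axes straight to vmap so that unrandomized leaves
# are not copied once per environment.
forces = jax.vmap(friction_force, in_axes=(in_axes, None))(sys_v, 9.81)
```

Compared with batching every leaf, this keeps `mass` at its original shape `(2,)` and gives only `friction` a batch axis of size 3, which saves memory when the system is large and few parameters are randomized.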