google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

TracedConfig does not influence rendering #265

Open peabody124 opened 1 year ago

peabody124 commented 1 year ago

I was trying to extend the domain randomization examples to also change collider shapes, using the code below. However, the renderer doesn't pick up the changes. This happens because the renderer gets its mesh information from json_format.MessageToDict, which does not see any of the changes held in TracedConfig.

I don't know enough about how json_format.MessageToDict reads the message internals to know what TracedConfig would need to override, but if anyone can point me in the right direction, I would be happy to make a PR.

def scale_bodies(config, body_scale_dict: dict):
  """Builds a custom tree that scales the capsule colliders of selected bodies.

  Every body in config.bodies gets an entry, so the tree lines up with the
  repeated `bodies` field. Bodies whose name contains a key of body_scale_dict
  have their collider positions, capsule lengths and capsule radii multiplied
  by the matching factor; all other bodies use a factor of 1.0.

  Args:
    config: brax Config proto to read the current collider values from.
    body_scale_dict: mapping from (a substring of) a body name to a scale factor.

  Returns:
    A dict pytree of the form {'bodies': [{'colliders': [...]}, ...]},
    parseable by the TracedConfig class.
  """

  def scale_body(b, x):
    # Scale every collider of body `b` by factor `x` (assumes capsule colliders).
    colliders = []
    for c in b.colliders:
      colliders.append({
        'position': {
          'x': c.position.x * x,
          'y': c.position.y * x,
          'z': c.position.z * x
        },
        'capsule': {
          'length': c.capsule.length * x,
          'radius': c.capsule.radius * x
        }
      })
    return {'colliders': colliders}

  custom_tree = {'bodies': []}
  for b in config.bodies:
    # Take the factor of the first key contained in the body name. Looking it up
    # via the matched key (not b.name) avoids a KeyError on partial matches.
    scale = next((v for k, v in body_scale_dict.items() if k in b.name), 1.0)
    custom_tree['bodies'].append(scale_body(b, scale))

  return custom_tree

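A toy model with no brax dependency can make this failure mode concrete. It assumes that TracedConfig acts as an overlay: overridden values are returned on attribute access, and every other attribute, including methods, is forwarded to the wrapped proto. `ToyMessage`, `ToyOverlayConfig` and `to_dict` are hypothetical stand-ins; `to_dict` only imitates the way MessageToDict lists fields through the message's own `ListFields()` method.

```python
class ToyMessage:
  """Tiny stand-in for a protobuf message (hypothetical, not protobuf)."""

  def __init__(self, **fields):
    self.__dict__.update(fields)

  def list_fields(self):
    # Like protobuf's ListFields(), read the fields stored on this object.
    return list(vars(self).items())


class ToyOverlayConfig:
  """Hypothetical stand-in for TracedConfig: overrides on top of a base object."""

  def __init__(self, base, overrides):
    self._base = base
    self._overrides = overrides

  def __getattr__(self, name):
    # Only called for names not found on the overlay itself.
    base_value = getattr(self._base, name)
    if name not in self._overrides:
      # Anything not overridden, including methods, comes from the base.
      return base_value
    override = self._overrides[name]
    if isinstance(override, dict):
      # Nested overrides: wrap the nested base object as well.
      return ToyOverlayConfig(base_value, override)
    return override


def to_dict(message):
  """Stand-in for json_format.MessageToDict: walks message.list_fields()."""
  return {
    name: to_dict(value) if isinstance(value, ToyMessage) else value
    for name, value in message.list_fields()
  }


base = ToyMessage(capsule=ToyMessage(length=0.46, radius=0.06))
traced = ToyOverlayConfig(base, {'capsule': {'length': 230.0}})
print(traced.capsule.length)  # 230.0: attribute access sees the override
print(to_dict(traced))        # {'capsule': {'length': 0.46, 'radius': 0.06}}
```

Because `list_fields` is not overridden, the overlay hands back the base object's bound method, so the serializer only ever sees the original values. Under this assumption, a fix inside TracedConfig would need to intercept the field-listing path as well, not only plain attribute reads.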
peabody124 commented 1 year ago

Here is a minimal snippet that reproduces the issue:

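from google.protobuf import json_format
from brax.envs import humanoid
# Assumed import path for TracedConfig in brax v1; adjust to your brax version.
from brax.experimental.tracing.customize import TracedConfig

env_fn = humanoid.Humanoid
config = env_fn().sys.config
mods = scale_bodies(config, {'right_thigh': 500.2})
updated_config = TracedConfig(config, mods)
# Body 3 is right_thigh: compare the exported dict with the TracedConfig value.
(json_format.MessageToDict(updated_config)['bodies'][3],
 updated_config.bodies[3].colliders[0].capsule.length)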
({'name': 'right_thigh',
  'colliders': [{'position': {'y': 0.005, 'z': -0.17},
    'rotation': {'x': -178.31532},
    'capsule': {'radius': 0.06, 'length': 0.46014702},
    'material': {'friction': 1.0}}],
  'inertia': {'x': 1.0, 'y': 1.0, 'z': 1.0},
  'mass': 4.751751,
  'frozen': {'position': {}, 'rotation': {}}},
 230.16554100513457)

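This shows that the updated capsule length (about 230, i.e. 0.46 × 500.2) is visible through the TracedConfig but is not exported by MessageToDict, which still reports the original length of 0.46.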
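Until the renderer can read TracedConfig directly, one possible workaround is to materialize the overrides into a plain config before rendering. The sketch below assumes the override tree mirrors the MessageToDict output (true for the single-word field names used in scale_bodies; MessageToDict converts multi-word names to lowerCamelCase unless preserving_proto_field_name=True) and that the values are concrete, i.e. taken outside jit/vmap or from a single environment of a batch. `merge_overrides` and `_merge` are hypothetical helpers, not brax APIs.

```python
import copy


def merge_overrides(base, overrides):
  """Returns a deep copy of `base` with `overrides` laid on top (hypothetical).

  `base` is a plain dict such as json_format.MessageToDict(config) returns;
  `overrides` is a custom tree like the one scale_bodies builds: dicts for
  messages, lists aligned index by index with repeated fields, numbers at
  the leaves. `base` itself is left unchanged.
  """
  return _merge(copy.deepcopy(base), overrides)


def _merge(target, overrides):
  if isinstance(overrides, dict):
    if not isinstance(target, dict):
      target = {}  # Field was absent, e.g. MessageToDict omitted a default.
    for key, value in overrides.items():
      target[key] = _merge(target.get(key), value)
    return target
  if isinstance(overrides, list):
    if not isinstance(target, list):
      target = []
    for i, value in enumerate(overrides):
      if i < len(target):
        target[i] = _merge(target[i], value)
      else:
        target.append(_merge(None, value))
    return target
  # Leaf: a concrete number (Python or JAX scalar, not a tracer). Convert it
  # to a plain float so that json_format.ParseDict can accept it later.
  return float(overrides)


thigh = {'name': 'right_thigh',
         'colliders': [{'position': {'y': 0.005, 'z': -0.17},
                        'capsule': {'radius': 0.06, 'length': 0.46014702}}]}
thigh_mods = {'colliders': [{'capsule': {'length': 230.16554, 'radius': 30.012}}]}
merged = merge_overrides(thigh, thigh_mods)
print(merged['colliders'][0]['capsule'])  # {'radius': 30.012, 'length': 230.16554}
```

The merged dict could then be parsed back into a proto with `json_format.ParseDict(merged, type(config)())` and used to build the system that is rendered; whether a given brax version's renderer accepts such a rebuilt system is an assumption worth checking.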