google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0
2.14k stars, 234 forks

Workarounds for Mujoco Menagerie Models #402

vyeevani closed this issue 8 months ago

vyeevani commented 9 months ago

One of the most attractive things to me about Brax is the possibility of highly parallelized sim-to-real applications for robotics. However, one issue I'm facing right now is that MuJoCo Menagerie models require a great deal of hand-holding and editing to be importable into Brax, because specific features are not supported in Brax.

The issues I've run into so far are:

It would be great to have a "permissive" mode for MJCF loading in brax.io to handle files that have surface-level violations like the above. (The last one is a bit trickier, but I figure capsules and cylinders aren't terribly far off.)
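
A "permissive" mode could start as a pre-pass that reports, rather than rejects, the constructs this thread runs into. The sketch below is hypothetical (find_unsupported is not a Brax API) and only checks the three constructs that the script further down rewrites: the geom priority attribute, elliptic friction cones, and cylinder geoms.

```python
import xml.etree.ElementTree as ET
from typing import List


def find_unsupported(xml: str) -> List[str]:
    """Hypothetical pre-pass: list MJCF constructs that, per this thread,
    needed manual edits before Brax would load the model."""
    root = ET.fromstring(xml)
    issues = []
    for geom in root.iter('geom'):
        name = geom.get('name', '<unnamed>')
        if 'priority' in geom.attrib:
            issues.append(f"geom {name}: 'priority' attribute")
        if geom.get('type') == 'cylinder':
            issues.append(f"geom {name}: cylinder (capsule is the closest substitute)")
    for option in root.iter('option'):
        if option.get('cone') == 'elliptic':
            issues.append("option: cone='elliptic' (pyramidal is the fallback)")
    return issues


if __name__ == '__main__':
    model = """
    <mujoco>
      <option cone="elliptic"/>
      <worldbody>
        <body name="hip">
          <geom name="hip_col" type="cylinder" size="0.046 0.02" priority="1"/>
        </body>
      </worldbody>
    </mujoco>
    """
    # A permissive loader would warn about these and apply fix-ups instead of failing.
    for issue in find_unsupported(model):
        print(issue)
```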

I've got a hacky script to handle this (tested and working on the Unitree Go1 from MuJoCo Menagerie):

import os
import xml.etree.ElementTree as ET


def build_parent_map(tree_root):
    # ElementTree elements don't store their parent, so build a child -> parent map.
    return {c: p for p in tree_root.iter() for c in p}


def belongs_to_collision_class(geom, parent_map):
    # Explicit class attribute on the geom itself.
    if geom.get('class') == 'collision':
        return True

    # Walk up the tree: a <default class="collision"> ancestor (including nested
    # sub-defaults) or a <body childclass="collision"> ancestor also counts.
    parent = parent_map.get(geom)
    while parent is not None:
        if (parent.get('class') == 'collision'
                or parent.get('childclass') == 'collision'):
            return True
        parent = parent_map.get(parent)

    return False


def cylinder_to_capsule_element(geom):
    # MuJoCo parameterizes capsules like cylinders: size="radius half_length"
    # along the geom's local z-axis (radius only when `fromto` is used), so
    # copying every attribute and changing only the type keeps pos/quat/fromto.
    # A synthesized `fromto` would be a poor substitute: MuJoCo ignores pos/quat
    # when `fromto` is set, so the cylinder's orientation would be lost.
    # The capsule is longer than the cylinder by `radius` at each end.
    capsule = ET.Element('geom', attrib=dict(geom.attrib))
    capsule.set('type', 'capsule')
    capsule.tail = geom.tail
    return capsule


def resolve_includes(xml_filename):
    root = ET.parse(xml_filename).getroot()
    parent_map = build_parent_map(root)

    for include_elem in root.findall(".//include"):
        include_file_path = os.path.join(
            os.path.dirname(xml_filename), include_elem.attrib['file'])
        # Recursively resolve any nested includes.
        include_root = ET.fromstring(resolve_includes(include_file_path))

        # Splice the included file's top-level children in place of <include>.
        # <include> need not be a direct child of the root, so remove it from
        # its actual parent.
        parent = parent_map[include_elem]
        index = list(parent).index(include_elem)
        parent.remove(include_elem)
        for offset, child in enumerate(include_root):
            parent.insert(index + offset, child)

    return ET.tostring(root, encoding='unicode')


def modify_xml_string(xml: str) -> str:
    root = ET.fromstring(xml)
    parent_map = build_parent_map(root)

    # Convert collision cylinders to capsules.
    for geom in root.findall(".//geom"):
        if (geom.get('type') == 'cylinder'
                and belongs_to_collision_class(geom, parent_map)):
            capsule = cylinder_to_capsule_element(geom)
            parent = parent_map[geom]
            # Replace in place so element order is preserved.
            parent[list(parent).index(geom)] = capsule

    # Remove the priority attribute from geom elements.
    for geom in root.findall(".//geom[@priority]"):
        del geom.attrib['priority']

    # Replace elliptic friction cones with pyramidal ones.
    for option in root.findall(".//option[@cone='elliptic']"):
        option.set('cone', 'pyramidal')

    return ET.tostring(root, encoding='unicode')


def loads(xml_filename: str) -> str:
    # Inline <include>s, then apply the Brax-compatibility edits.
    return modify_xml_string(resolve_includes(xml_filename))

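
The cylinder-to-capsule swap is not exact. In MuJoCo both types take size="radius half_length" along the geom's local z-axis, but a capsule adds a hemispherical cap of the same radius at each end, so it reaches radius further than the cylinder on each side. If matching the end-to-end length matters more than matching the straight section, a helper like the hypothetical capsule_size_for_cylinder below shortens the half-length by the radius. This is one possible choice, not something Brax or Menagerie prescribes, and short, wide cylinders (half-length not above the radius) have no capsule match at all.

```python
from typing import Optional


def capsule_size_for_cylinder(radius: float, half_length: float) -> Optional[str]:
    """Hypothetical helper: a capsule `size` string with the same end-to-end
    length along z as a MuJoCo cylinder with size="radius half_length".

    Cylinder length: 2 * half_length.
    Capsule length:  2 * (capsule_half_length + radius), because of the caps.
    """
    if half_length <= radius:
        # No capsule with this radius is that short; keep the cylinder's size
        # (accepting the extra length) or pick another primitive.
        return None
    capsule_half_length = half_length - radius
    return f"{radius:g} {capsule_half_length:g}"


# Radius 0.02, half-length 0.05 (total length 0.1):
print(capsule_size_for_cylinder(0.02, 0.05))  # -> 0.02 0.03
```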
btaba commented 9 months ago

Hi @vyeevani, we will have an announcement in the coming weeks that makes Brax more interoperable with MuJoCo; stay tuned on the announcements page! We also plan to add "brax-compatible" XMLs to Menagerie in the coming weeks/months.

JoeMWatson commented 8 months ago

Are you able to comment on how performant these more complex models will be in brax?

I have recently tried simulating the Franka Panda manipulator in Brax by taking the Menagerie XML and making it Brax-compliant. While the model loaded successfully, JIT-compiling the step and reset functions was incredibly slow (on my M1 MacBook), and I got several warnings such as

********************************
[Compiling module jit_reset] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************

This surprised me, as the Google Barkour experimental model JITs quite happily, as do complex Gym-based environments like Humanoid.

Do you see there being future performance improvements in this setting, or is it just a consequence of the more complex kinematic chain and collision models?

btaba commented 8 months ago

Hi @JoeMWatson, I'm somewhat surprised Google Barkour JITs happily. Generally, Brax does not scale well with the number of contacts, especially if they are convex<>convex collisions. I suggest turning off self-collisions for the Panda arm (or, if you need self-collisions, replacing the meshes with primitives like capsules OR using joint limits in a smarter way). Panda may be more expensive if its meshes are larger.
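
One way to turn off self-collisions without touching the geometry is MuJoCo's contype/conaffinity bitmask filter: two geoms are considered for contact only if (contype1 & conaffinity2) || (contype2 & conaffinity1) is non-zero, and both attributes default to 1. The sketch below sets conaffinity="0" on every robot geom, which rules out robot<>robot pairs while keeping robot<>floor pairs for a floor left at the defaults. It assumes the floor geoms are identified by name (floor_names is a hypothetical parameter), and it is worth checking that the Brax version in use honors these attributes when it builds contact pairs.

```python
import xml.etree.ElementTree as ET
from typing import Iterable


def can_collide(contype1: int, conaffinity1: int,
                contype2: int, conaffinity2: int) -> bool:
    """MuJoCo's contype/conaffinity compatibility test for a geom pair."""
    return bool((contype1 & conaffinity2) or (contype2 & conaffinity1))


def disable_self_collisions(xml: str, floor_names: Iterable[str] = ('floor',)) -> str:
    """Sketch: set conaffinity="0" on every non-floor geom under <worldbody>.

    Any two such geoms then fail the test above whatever their contype, while
    a robot geom with contype=1 still pairs with a floor left at the defaults
    (contype=1, conaffinity=1). Geoms inside <default> are left untouched.
    """
    floor_names = set(floor_names)
    root = ET.fromstring(xml)
    for worldbody in root.iter('worldbody'):
        for geom in worldbody.iter('geom'):
            if geom.get('name') not in floor_names:
                geom.set('conaffinity', '0')
    return ET.tostring(root, encoding='unicode')
```

Changing only conaffinity (rather than contype) means visual-only geoms that inherit contype="0" from a default class stay collision-free.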

Check out the fork of Google Barkour here, which keeps only feet<>plane contacts; that is sufficient for learning joystick policies.
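
To see why keeping only feet<>plane contacts helps, it is enough to count how many geom pairs survive the contype/conaffinity filter. The toy model below (one floor, four feet, eight other collision geoms; the counts are made up for illustration) goes from every pair being a candidate to just the four foot<>floor pairs. The sketch ignores MuJoCo's other filters, such as skipping geoms in the same body or in parent-child bodies, so a real model's unpruned count would be lower than shown here.

```python
from itertools import combinations
from typing import List, Tuple

# (name, contype, conaffinity) for each collision geom.
Geom = Tuple[str, int, int]


def candidate_pairs(geoms: List[Geom]) -> List[Tuple[str, str]]:
    """Pairs that pass MuJoCo's contype/conaffinity bitmask test."""
    return [
        (n1, n2)
        for (n1, t1, a1), (n2, t2, a2) in combinations(geoms, 2)
        if (t1 & a2) or (t2 & a1)
    ]


floor = [('floor', 1, 1)]
feet = [(f'foot_{i}', 1, 1) for i in range(4)]
others = [(f'link_{i}', 1, 1) for i in range(8)]

# Defaults: every geom can touch every other geom.
print(len(candidate_pairs(floor + feet + others)))  # -> 78 (13 choose 2)

# Feet keep contype=1 but get conaffinity=0; other links opt out entirely.
pruned = floor + [(n, 1, 0) for n, _, _ in feet] + [(n, 0, 0) for n, _, _ in others]
print(len(candidate_pairs(pruned)))  # -> 4 (one per foot, against the floor)
```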