google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Geom fromto attribute not supported on fused bodies #359

Closed: Jgorin closed this issue 1 year ago

Jgorin commented 1 year ago

I'm trying to load a MuJoCo model with a single body holding a capsule geom that specifies fromto, but I'm running into an error:

Traceback (most recent call last):
  File "/Users/josh/InfiniteSky/hermes/rl/src/isky_rl_algorithms/common/env_wrappers/brax_test.py", line 43, in <module>
    sys = mjcf.load(model_path)
  File "/Users/josh/InfiniteSky/brax/brax/io/mjcf.py", line 530, in load
    mj = mujoco.MjModel.from_xml_string(xml, assets=assets)
ValueError: Error: both pos and fromto defined in geom 'pelvis_geom' (id = 0)
Object name = pelvis_geom, id = 0, line = 22, column = -1

Here is the model I'm trying to load:

<mujoco model="humanoid">
    <compiler angle="degree" inertiafromgeom="true" />
    <default>
        <joint limited="true" />
        <geom rgba=".8 .6 .4 1" condim="4" />
        <default class="damped_joint">
            <joint damping="2" stiffness="2.0" armature="0.02" />
        </default>
    </default>
    <option timestep="0.0008" iterations="10" solver="PGS" collision="all">
        <flag energy="enable" />
    </option>
    <asset>
        <texture type="skybox" builtin="gradient" rgb1="1 1 1" rgb2=".6 .8 1" width="256" height="256" />
        <texture name="texplane" type="2d" builtin="checker" rgb1=".2 .3 .4" rgb2=".1 0.15 0.2" width="512" height="512" />

        <material name="MatPlane" reflectance="0.5" texture="texplane" texrepeat="1 1" texuniform="true" />
    </asset>
    <worldbody>

        <body name="pelvis" pos="0.385 0.32 1.26918883">
            <!-- <joint name="lwaist_x" pos="0 0 0" axis="0 0 1" range="0 1" type="hinge" /> -->
            <site name="pelvis" pos="0 0 0" />
            <geom fromto="-0.081889035416171 0.009999999999999983 0.0 0.081889035416171 0.010000000000000018 0.0" name="pelvis_geom" size="0.07" mass="10.928" type="capsule" />
        </body>

    </worldbody>

    <actuator>
        <!-- <motor gear="1" joint="lwaist_x" name="lwaist_x" /> -->
    </actuator>

</mujoco>

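I think the error comes from _fuse_bodies in brax/io/mjcf.py. That function merges jointless bodies into their parent, and when the merged body has a nonzero pos it writes pos and quat onto each child geom. It doesn't handle a geom that already specifies fromto, so the geom ends up with both pos and fromto, which MuJoCo rejects.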
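A minimal, standard-library-only sketch of why MuJoCo rejects the fused model. naive_fuse is a hypothetical, simplified stand-in for _fuse_bodies (translation only, no rotation, and it only checks for joint children), not Brax's actual code. It hoists the jointless pelvis body's geom into worldbody and writes a pos attribute next to the existing fromto; the XML is a trimmed version of the model above.

```python
import xml.etree.ElementTree as ET

XML = """
<mujoco>
  <worldbody>
    <body name="pelvis" pos="0.385 0.32 1.26918883">
      <geom name="pelvis_geom" type="capsule" size="0.07"
            fromto="-0.08 0.01 0 0.08 0.01 0"/>
    </body>
  </worldbody>
</mujoco>
"""


def naive_fuse(parent: ET.Element) -> None:
  """Hypothetical, translation-only version of fusing jointless bodies."""
  for child in list(parent):  # copy, since parent's children are modified
    if child.tag != "body" or child.find("joint") is not None:
      continue
    cpos = [float(v) for v in child.attrib.get("pos", "0 0 0").split()]
    for grandchild in list(child):
      if grandchild.tag == "geom" and any(cpos):
        gpos = [float(v) for v in grandchild.attrib.get("pos", "0 0 0").split()]
        # Writing pos unconditionally is the problem: the geom may use fromto.
        grandchild.attrib["pos"] = " ".join(str(a + b) for a, b in zip(cpos, gpos))
      parent.append(grandchild)
    parent.remove(child)


root = ET.fromstring(XML)
worldbody = root.find("worldbody")
naive_fuse(worldbody)
geom = worldbody.find("geom")
print(sorted(geom.attrib))  # -> ['fromto', 'name', 'pos', 'size', 'type']
```

A geom carrying both pos and fromto is exactly what MuJoCo's compiler reports as "both pos and fromto defined".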

I'm not sure what the best way to handle this is, but here's the change I made that at least gets the XML to parse:

def _fuse_bodies(elem: ElementTree.Element):
  """Fuses together parent child bodies that have no joint."""

  for child in list(elem):  # we will modify elem children, so make a copy
    _fuse_bodies(child)
    # this only applies to bodies with no joints
    if child.tag != 'body':
      continue
    if child.find('joint') is not None or child.find('freejoint') is not None:
      continue
    cpos = child.attrib.get('pos', '0 0 0')
    cpos = np.fromstring(cpos, sep=' ')
    cquat = child.attrib.get('quat', '1 0 0 0')
    cquat = np.fromstring(cquat, sep=' ')
    for grandchild in child:
      # TODO: might need to offset more than just body, geom
      # NOTE: a body with zero pos but a non-identity quat skips this branch.
      if grandchild.tag in ('body', 'geom') and (cpos != 0).any():
        gcfromto = grandchild.attrib.get('fromto', None)
        if gcfromto:
          # fromto holds two endpoints: move each one into the parent frame as
          # a point and leave pos/quat unset, so MuJoCo doesn't see both.
          fromto = np.fromstring(gcfromto, sep=' ')
          identity = np.array([1.0, 0.0, 0.0, 0.0])
          from_pos, _ = _transform_do(cpos, cquat, fromto[:3], identity)
          to_pos, _ = _transform_do(cpos, cquat, fromto[3:], identity)
          gcfromto = ' '.join('%f' % i for i in np.concatenate([from_pos, to_pos]))
          grandchild.attrib['fromto'] = gcfromto
        else:
          gcpos = grandchild.attrib.get('pos', '0 0 0')
          gcquat = grandchild.attrib.get('quat', '1 0 0 0')
          gcpos = np.fromstring(gcpos, sep=' ')
          gcquat = np.fromstring(gcquat, sep=' ')
          gcpos, gcquat = _transform_do(cpos, cquat, gcpos, gcquat)
          gcpos = ' '.join('%f' % i for i in gcpos)
          gcquat = ' '.join('%f' % i for i in gcquat)
          grandchild.attrib['pos'] = gcpos
          grandchild.attrib['quat'] = gcquat

      elem.append(grandchild)
    elem.remove(child)
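The fromto branch treats each endpoint as a point and maps it into the parent frame as p' = cpos + R(cquat) p. The following is a self-contained numpy sketch of that step, assuming MuJoCo's (w, x, y, z) quaternion order and assuming _transform_do does exactly this to a position. quat_rotate and transform_fromto are hypothetical helper names, not Brax functions.

```python
import numpy as np


def quat_rotate(q: np.ndarray, v: np.ndarray) -> np.ndarray:
  """Rotates vector v by unit quaternion q given as (w, x, y, z)."""
  w, u = q[0], q[1:]
  # Expansion of q * v * conj(q): v + 2w(u x v) + 2u x (u x v).
  return v + 2.0 * np.cross(u, np.cross(u, v) + w * v)


def transform_fromto(fromto: str, cpos: np.ndarray, cquat: np.ndarray) -> str:
  """Maps both fromto endpoints from the child body frame into the parent frame."""
  vals = np.array(fromto.split(), dtype=float)  # tolerates repeated whitespace
  if vals.shape != (6,):
    raise ValueError(f"fromto needs 6 numbers, got {vals.size}")
  start = cpos + quat_rotate(cquat, vals[:3])
  end = cpos + quat_rotate(cquat, vals[3:])
  # repr keeps full float precision; '%f' would round to six decimals.
  return " ".join(repr(float(x)) for x in np.concatenate([start, end]))


cpos = np.array([0.385, 0.32, 1.26918883])
cquat = np.array([1.0, 0.0, 0.0, 0.0])
print(transform_fromto("-0.081889035416171 0.01 0.0 0.081889035416171 0.01 0.0", cpos, cquat))
```

Only the endpoints move; the capsule's radius (size) is frame-independent, so nothing else on the geom needs rewriting.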

After this change the XML parses, but I now get an error when initializing the system state (this time with a model that does contain joints, so there is state to set):

Traceback (most recent call last):
  File "/Users/josh/InfiniteSky/hermes/rl/src/isky_rl_algorithms/common/env_wrappers/brax_test.py", line 45, in <module>
    env.reset()
  File "/Users/josh/InfiniteSky/hermes/rl/src/isky_rl_algorithms/common/env_wrappers/brax_test.py", line 29, in reset
    pipeline_state = self.pipeline_init(qpos, qvel)
  File "/Users/josh/InfiniteSky/brax/brax/envs/base.py", line 114, in pipeline_init
    return self._pipeline.init(self.sys, q, qd, self._debug)
  File "/Users/josh/InfiniteSky/brax/brax/generalized/pipeline.py", line 48, in init
    state = constraint.jacobian(sys, state)
  File "/Users/josh/InfiniteSky/brax/brax/generalized/constraint.py", line 183, in jacobian
    jpds = jac_contact(sys, state), jac_limit(sys, state)
  File "/Users/josh/InfiniteSky/brax/brax/generalized/constraint.py", line 142, in jac_contact
    contact = geometry.contact(sys, state.x)
  File "/Users/josh/InfiniteSky/brax/brax/geometry/contact.py", line 499, in contact
    tx_i = x.take(geom_i.link_idx).vmap().do(geom_i.transform)
  File "/Users/josh/InfiniteSky/brax/brax/base.py", line 61, in take
    return tree_map(lambda x: jp.take(x, i, axis=axis, mode='wrap'), self)
  File "/Users/josh/InfiniteSky/brax/brax/base.py", line 61, in <lambda>
    return tree_map(lambda x: jp.take(x, i, axis=axis, mode='wrap'), self)
  File "/Users/josh/InfiniteSky/hermes/py_env/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 3691, in take
    return _take(a, indices, None if axis is None else operator.index(axis), out,
  File "/Users/josh/InfiniteSky/hermes/py_env/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/josh/InfiniteSky/hermes/py_env/lib/python3.9/site-packages/jax/_src/pjit.py", line 238, in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
  File "/Users/josh/InfiniteSky/hermes/py_env/lib/python3.9/site-packages/jax/_src/pjit.py", line 180, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/Users/josh/InfiniteSky/hermes/py_env/lib/python3.9/site-packages/jax/_src/api.py", line 311, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/Users/josh/InfiniteSky/hermes/py_env/lib/python3.9/site-packages/jax/_src/pjit.py", line 480, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "/Users/josh/InfiniteSky/hermes/py_env/lib/python3.9/site-packages/jax/_src/pjit.py", line 918, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
  File "/Users/josh/InfiniteSky/hermes/py_env/lib/python3.9/site-packages/jax/_src/linear_util.py", line 322, in memoized_fun
    ans = call(fun, *args)
  File "/Users/josh/InfiniteSky/hermes/py_env/lib/python3.9/site-packages/jax/_src/pjit.py", line 874, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/Users/josh/InfiniteSky/hermes/py_env/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/Users/josh/InfiniteSky/hermes/py_env/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py", line 2049, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/Users/josh/InfiniteSky/hermes/py_env/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py", line 2066, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/josh/InfiniteSky/hermes/py_env/lib/python3.9/site-packages/jax/_src/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/josh/InfiniteSky/hermes/py_env/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 3700, in _take
    util.check_arraylike("take", a, indices)
  File "/Users/josh/InfiniteSky/hermes/py_env/lib/python3.9/site-packages/jax/_src/numpy/util.py", line 343, in check_arraylike
    raise TypeError(msg.format(fun_name, type(arg), pos))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: take requires ndarray or scalar arguments, got <class 'NoneType'> at position 1.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/josh/InfiniteSky/hermes/rl/src/isky_rl_algorithms/common/env_wrappers/brax_test.py", line 45, in <module>
    env.reset()
  File "/Users/josh/InfiniteSky/hermes/rl/src/isky_rl_algorithms/common/env_wrappers/brax_test.py", line 29, in reset
    pipeline_state = self.pipeline_init(qpos, qvel)
  File "/Users/josh/InfiniteSky/brax/brax/envs/base.py", line 114, in pipeline_init
    return self._pipeline.init(self.sys, q, qd, self._debug)
  File "/Users/josh/InfiniteSky/brax/brax/generalized/pipeline.py", line 48, in init
    state = constraint.jacobian(sys, state)
  File "/Users/josh/InfiniteSky/brax/brax/generalized/constraint.py", line 183, in jacobian
    jpds = jac_contact(sys, state), jac_limit(sys, state)
  File "/Users/josh/InfiniteSky/brax/brax/generalized/constraint.py", line 142, in jac_contact
    contact = geometry.contact(sys, state.x)
  File "/Users/josh/InfiniteSky/brax/brax/geometry/contact.py", line 499, in contact
    tx_i = x.take(geom_i.link_idx).vmap().do(geom_i.transform)
  File "/Users/josh/InfiniteSky/brax/brax/base.py", line 61, in take
    return tree_map(lambda x: jp.take(x, i, axis=axis, mode='wrap'), self)
  File "/Users/josh/InfiniteSky/hermes/py_env/lib/python3.9/site-packages/jax/_src/tree_util.py", line 210, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/Users/josh/InfiniteSky/hermes/py_env/lib/python3.9/site-packages/jax/_src/tree_util.py", line 210, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/Users/josh/InfiniteSky/brax/brax/base.py", line 61, in <lambda>
    return tree_map(lambda x: jp.take(x, i, axis=axis, mode='wrap'), self)
  File "/Users/josh/InfiniteSky/hermes/py_env/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 3691, in take
    return _take(a, indices, None if axis is None else operator.index(axis), out,
  File "/Users/josh/InfiniteSky/hermes/py_env/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 3700, in _take
    util.check_arraylike("take", a, indices)
  File "/Users/josh/InfiniteSky/hermes/py_env/lib/python3.9/site-packages/jax/_src/numpy/util.py", line 343, in check_arraylike
    raise TypeError(msg.format(fun_name, type(arg), pos))
TypeError: take requires ndarray or scalar arguments, got <class 'NoneType'> at position 1.

Here's how I'm calling pipeline_init:

    qpos = self.sys.init_q + jax.random.uniform(
        rng1, (self.sys.q_size(),), minval=-0.01, maxval=0.01
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.qd_size(),), minval=-0.01, maxval=0.01
    )
    pipeline_state = self.pipeline_init(qpos, qvel)
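The final TypeError is raised by jnp.take inside Brax's take helper in brax/base.py, and it says the index argument (position 1) was None. That suggests geom_i.link_idx was None for at least one geom, rather than anything being wrong with qpos or qvel; this is an inference from the trace, not something confirmed in the thread. The sketch below reproduces only the JAX behavior, assuming a recent jax install; link_positions is a hypothetical stand-in array.

```python
import jax.numpy as jnp

link_positions = jnp.arange(6.0).reshape(3, 2)  # stand-in for per-link data

# A valid index array gathers rows as expected (mode="wrap" as in the trace).
rows = jnp.take(link_positions, jnp.array([0, 2]), axis=0, mode="wrap")

# A None index fails JAX's array-like argument check with a TypeError.
try:
  jnp.take(link_positions, None, axis=0, mode="wrap")
  error = None
except TypeError as exc:
  error = exc
print(error)
```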

btaba commented 1 year ago

Hi @Jgorin, thanks for using Brax! And thanks for the fix for the first bug: we should indeed handle the case where a fused body's geom uses fromto instead of pos/quat. We'll push your fix shortly. The second bug is in the geom-pair logic, and we'll push a fix for that as well.