google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Bug when I create a fixed arena and set conaffinity=contype=1 #355

Closed. queenxy closed this issue 1 year ago.

queenxy commented 1 year ago

I want to create a robot soccer environment in Brax v2, but I get an error. It occurs because I set conaffinity=contype=1 for the arena geoms. When I set them to 0, the error disappears, but the ball passes straight through the arena walls. This is my XML:

<mujoco model="robo_soccer">
    <compiler angle="radian" inertiafromgeom="true"/>
    <default>
        <joint armature="1" damping="1" limited="true"/>
        <geom conaffinity="1" contype="1" friction="0.1 0.1 0.1"/>
    </default>
    <!-- Removed RK4 integrator for brax. -->
    <option gravity="0 0 0" timestep="0.01" />

    <worldbody>
        <!-- Arena -->
        <geom name="ground" pos="0 0 0" size="1 1 10" type="plane"/>
        <geom fromto="-0.75 -0.65 0.01 0.75 -0.65 0.02135" name="sideS" size=".03" type="capsule"/>
        <geom fromto="-0.75 -0.65 0.01 -0.75 -0.2 0.02135" name="sideE1" size=".03" type="capsule"/>
        <geom fromto="-0.75 0.65 0.01 -0.75 0.2 0.02135" name="sideE2" size=".03" type="capsule"/>
        <geom fromto="-0.75 0.65 0.01 0.75 0.65 0.02135" name="sideN" size=".03" type="capsule"/>
        <geom fromto="0.75 -0.65 0.01 0.75 -0.2 0.02135" name="sideW1" size=".03" type="capsule"/>
        <geom fromto="0.75 0.65 0.01 0.75 0.2 0.02135" name="sideW2" size=".03" type="capsule"/>
        <geom fromto="0.85 -0.2 0.01 0.85 0.2 0.02135" name="sideG1" size=".03" type="capsule"/>
        <geom fromto="-0.85 -0.2 0.01 -0.85 0.2 0.02135" name="sideG2" size=".03" type="capsule"/>
        <geom fromto="-0.75 -0.2 0.01 -0.85 -0.2 0.02135" name="side1" size=".03" type="capsule"/>
        <geom fromto="-0.75 0.2 0.01 -0.85 0.2 0.02135" name="side2" size=".03" type="capsule"/>
        <geom fromto="0.75 -0.2 0.01 0.85 -0.2 0.02135" name="side3" size=".03" type="capsule"/>
        <geom fromto="0.75 0.2 0.01 0.85 0.2 0.02135" name="side4" size=".03" type="capsule"/>
        <!-- Ball -->
        <body name="ball" pos="0 0 0.02135">
            <geom contype="1" conaffinity="1" name="ball" pos="0 0 0" size="0.02135" type="sphere" mass="0.046"/>
            <joint name="ball_slidey" type="slide" pos="0 0 0" axis="0 1 0" range="-0.65 0.65" damping="0.05"/>
            <joint name="ball_slidex" type="slide" pos="0 0 0" axis="1 0 0" range="-0.85 0.85" damping="0.05"/>
        </body>
        <!-- Agent -->
        <body name="B1" pos="-0.5 0 0.03">
            <geom conaffinity="1" contype="1" name="agent_b1" fromto="0 0 -0.04 0 0 0.05" size="0.04" type="capsule" mass="1"/>
            <joint name="b1_slidey" type="slide" pos="0 0 0" axis="0 1 0" range="-0.65 0.65" damping="0.05"/>
            <joint name="b1_slidex" type="slide" pos="0 0 0" axis="1 0 0" range="-0.25 1.25" damping="0.05"/>
        </body>
    </worldbody>
    <actuator>
        <motor ctrllimited="true" ctrlrange="-1 1" gear="0.1" joint="b1_slidex" name="b1_slidex"/>
        <motor ctrllimited="true" ctrlrange="-1 1" gear="0.1" joint="b1_slidey" name="b1_slidey"/>
    </actuator>
</mujoco>

And this is the error I got:

Traceback (most recent call last):
  File "/home/qxy/brax_soccer/soccer_braxv2/v2test.py", line 27, in <module>
    state = jit_env_step(state, act)
  File "/home/qxy/brax_soccer/soccer_braxv2/fieldv2.py", line 99, in step
    pipeline_state = self.pipeline_step(pipeline_state0, action)
  File "/home/qxy/miniconda3/envs/braxv2/lib/python3.10/site-packages/brax/envs/env.py", line 127, in pipeline_step
    return jax.lax.scan(f, pipeline_state, (), self._n_frames)[0]
  File "/home/qxy/miniconda3/envs/braxv2/lib/python3.10/site-packages/brax/envs/env.py", line 123, in f
    self._pipeline.step(self.sys, state, action, self._debug),
  File "/home/qxy/miniconda3/envs/braxv2/lib/python3.10/site-packages/brax/positional/pipeline.py", line 88, in step
    contact = geometry.contact(sys, x)
  File "/home/qxy/miniconda3/envs/braxv2/lib/python3.10/site-packages/brax/geometry/contact.py", line 499, in contact
    tx_i = x.take(geom_i.link_idx).vmap().do(geom_i.transform)
  File "/home/qxy/miniconda3/envs/braxv2/lib/python3.10/site-packages/brax/base.py", line 61, in take
    return tree_map(lambda x: jp.take(x, i, axis=axis, mode='wrap'), self)
  File "/home/qxy/miniconda3/envs/braxv2/lib/python3.10/site-packages/brax/base.py", line 61, in <lambda>
    return tree_map(lambda x: jp.take(x, i, axis=axis, mode='wrap'), self)
  File "/home/qxy/miniconda3/envs/braxv2/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3691, in take
    return _take(a, indices, None if axis is None else operator.index(axis), out,
  File "/home/qxy/miniconda3/envs/braxv2/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3700, in _take
    util.check_arraylike("take", a, indices)
  File "/home/qxy/miniconda3/envs/braxv2/lib/python3.10/site-packages/jax/_src/numpy/util.py", line 343, in check_arraylike
    raise TypeError(msg.format(fun_name, type(arg), pos))
TypeError: take requires ndarray or scalar arguments, got <class 'NoneType'> at position 1.
queenxy commented 1 year ago

The error occurs at line 499 of brax/geometry/contact.py:

 tx_i = x.take(geom_i.link_idx).vmap().do(geom_i.transform)

link_idx is None for every arena geom. It is set that way in brax/io/mjcf.py, line 348:

'link_idx': mj.geom_bodyid[i] - 1 if mj.geom_bodyid[i] > 0 else None,
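Arena geoms sit directly under <worldbody>, so MuJoCo assigns them geom_bodyid 0 and the line above maps them to link_idx None. The pure-Python sketch below mirrors that mapping and flags candidate contact pairs in which both geoms are world geoms, i.e. pairs with no moving link for x.take to look up. The helper names, geom_bodyid values and pair list are hypothetical, chosen to resemble the arena, ball and agent in the XML above; this illustrates the failure mode and is not Brax's actual code.

```python
from typing import List, Optional, Tuple


def link_idx_for(geom_bodyid: List[int]) -> List[Optional[int]]:
    """Mirror of the quoted mjcf.py expression: body 0 (the world) maps to None."""
    return [b - 1 if b > 0 else None for b in geom_bodyid]


def world_world_pairs(
    pairs: List[Tuple[int, int]], link_idx: List[Optional[int]]
) -> List[Tuple[int, int]]:
    """Return candidate pairs in which neither geom belongs to a moving link."""
    return [(i, j) for i, j in pairs if link_idx[i] is None and link_idx[j] is None]


# Hypothetical layout: geoms 0-2 are arena geoms on the world body,
# geom 3 is the ball (body 1) and geom 4 is the agent (body 2).
geom_bodyid = [0, 0, 0, 1, 2]
link_idx = link_idx_for(geom_bodyid)
pairs = [(1, 2), (0, 3), (1, 3), (3, 4)]

print(link_idx)                            # [None, None, None, 0, 1]
print(world_world_pairs(pairs, link_idx))  # [(1, 2)]
```

A pair such as (1, 2), two static arena capsules, has nothing to simulate, which is why skipping world-vs-world pairs is a natural way to avoid the crash.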
queenxy commented 1 year ago

So I wonder: how can I fix the arena to the ground and enable collision detection at the same time?

btaba commented 1 year ago

Hi @queenxy, thanks for using Brax! For the arena, I believe you'd want contype="1" conaffinity="0", since you don't want the arena colliding with itself. We'll create a fix so that this doesn't crash; world geoms are static and wouldn't collide with each other anyway.
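The suggestion works because of MuJoCo's collision filter: two geoms can collide only if the contype of one shares a set bit with the conaffinity of the other. The sketch below applies contype="1" conaffinity="0" to the direct <worldbody> geoms of a trimmed, hypothetical excerpt of the XML above, then evaluates that bitmask rule for a wall-wall pair and a wall-ball pair. It uses only the standard library; the helper names are illustrative and this is not how Brax's loader is implemented. It also shows why setting both attributes to 0 lets the ball pass through: every bitwise AND with the ball's masks is then 0.

```python
import xml.etree.ElementTree as ET

# Hypothetical, trimmed excerpt of the XML above: two arena walls and the ball.
MJCF = """
<mujoco>
  <worldbody>
    <geom name="sideS" type="capsule" size=".03" fromto="-0.75 -0.65 0.01 0.75 -0.65 0.02135"/>
    <geom name="sideN" type="capsule" size=".03" fromto="-0.75 0.65 0.01 0.75 0.65 0.02135"/>
    <body name="ball" pos="0 0 0.02135">
      <geom name="ball" type="sphere" size="0.02135" contype="1" conaffinity="1"/>
    </body>
  </worldbody>
</mujoco>
"""


def masks(geom: ET.Element) -> tuple:
    """(contype, conaffinity); unset attributes fall back to MuJoCo's default of 1."""
    return int(geom.get("contype", "1")), int(geom.get("conaffinity", "1"))


def can_collide(a: tuple, b: tuple) -> bool:
    """MuJoCo bitmask filter: contype of one geom must share a bit with conaffinity of the other."""
    return bool((a[0] & b[1]) or (b[0] & a[1]))


root = ET.fromstring(MJCF)
worldbody = root.find("worldbody")

# Apply the suggested fix to static arena geoms only (direct children of <worldbody>).
for geom in worldbody.findall("geom"):
    geom.set("contype", "1")
    geom.set("conaffinity", "0")

geoms = {g.get("name"): masks(g) for g in root.iter("geom")}
print(can_collide(geoms["sideS"], geoms["sideN"]))  # False: wall vs. wall
print(can_collide(geoms["sideS"], geoms["ball"]))   # True: wall vs. ball
```

Note that the ball keeps contype="1" conaffinity="1", so the ball and the agent still collide with each other as before.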