google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

How many nested bodies does Brax support? (w.r.t. rope and wire) #386

Open baixianger opened 10 months ago

baixianger commented 10 months ago

I implemented my rope as a chain of nested rigid bodies and joints and ran some experiments in Brax. I found that Brax only supports a limited depth of nested bodies.

Would you happen to have a solution for modeling a rope?

<mujoco model="panda nohand">
  <compiler angle="radian" meshdir="assets" autolimits="true"/>

  <default>
    <default class="panda">
      <material specular="0.5" shininess="0.25"/>
      <joint armature="0.1" damping="1" axis="0 0 1" range="-2.8973 2.8973" type="hinge"/>
      <general dyntype="none" biastype="affine" ctrlrange="-2.8973 2.8973" forcerange="-87 87"/>
      <default class="finger">
        <joint axis="0 1 0" type="slide" range="0 0.04"/>
      </default>

      <default class="panda/visual">
        <geom type="mesh" contype="0" conaffinity="0" group="2" margin="1.0" gap="1.0"/>
      </default>
      <default class="panda/collision">
        <geom type="mesh" group="3" margin="1.0" gap="1.0"/>
      </default>
      <site size="0.001" rgba="0.5 0.5 0.5 0.3" group="4"/>
    </default>

    <default class="whip">
      <default class="X">
        <joint type="hinge" axis="1 0 0" pos="0 0 0" springref="0" stiffness="0.242" damping="0.092" /> 
      </default>
      <default class="Y">
        <joint type="hinge" axis="0 1 0" pos="0 0 0" springref="0" stiffness="0.242" damping="0.092" /> 
      </default>
      <default class="Z">
        <geom type="capsule" material="white" fromto="0 0 0 0 0 0.03" size="0.006" mass="0.012" />
      </default>
    </default>  

    <default class="target">
      <geom rgba=".5 .5 .5 .4" group="3"/>
    </default>
  </default>

  <custom>
      <!-- brax custom params -->
        <!-- <numeric data="1.3e-06 0.66 -0.00015 -0.073 0.058 2.4 2.9 -0.16  0.76 -0.12
                   0.41  -0.07  0.21  -0.037 0.11  -0.02  0.059  -0.01  0.032 -0.0056
                   0.017 -0.0031 0.0099 -0.0018 0.0059 -0.0011 0.0038 -0.00075 0.0026  -0.00057
                   0.002 -0.00049 0.0016  -0.00045 0.0015  -0.00043 0.0014 -0.00043 0.0013 -0.00043 
                   0.0012 -0.00043 0.0012 -0.00043 0.0011 -0.00041 0.00099 -0.00038 0.00088 -0.00034 
                   0.00075 -0.00029 0.0006 -0.00023 0.00045 -0.00016 0.0003 -9.2e-05 0.00017 -3.5e-05 
                   6.1e-05 0 0  -0.9 -5.997 0.022 1 0 0 0" 
              name="init_qpos"/> -->
  </custom>

  <asset>
    <texture builtin="gradient" height="100" rgb1=".4 .5 .6" rgb2="0 0 0" type="skybox" width="100"/>
    <texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
    <texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
    <material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="60 60" texture="texplane"/>
    <material name="geom" texture="texgeom" texuniform="true"/>

    <material class="panda" name="white" rgba="1 1 1 1"/>
    <material class="panda" name="off_white" rgba="0.901961 0.921569 0.929412 1"/>
    <material class="panda" name="dark_grey" rgba="0.25 0.25 0.25 1"/>
    <material class="panda" name="green" rgba="0 1 0 1"/>
    <material class="panda" name="light_blue" rgba="0.039216 0.541176 0.780392 1"/>

    <!-- Collision meshes -->
    <mesh name="link0_c" file="link0.stl"/>
    <mesh name="link1_c" file="link1.stl"/>
    <mesh name="link2_c" file="link2.stl"/>
    <mesh name="link3_c" file="link3.stl"/>
    <mesh name="link4_c" file="link4.stl"/>
    <mesh name="link5_c0" file="link5_collision_0.obj"/>
    <mesh name="link5_c1" file="link5_collision_1.obj"/>
    <mesh name="link5_c2" file="link5_collision_2.obj"/>
    <mesh name="link6_c" file="link6.stl"/>
    <mesh name="link7_c" file="link7.stl"/>

  </asset>

  <worldbody>
    <light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
    <geom material="MatPlane" name="floor" pos="0 0 0" size="20 20 0.125" type="plane"/>
    <body name="link0" childclass="panda">
      <joint name="dummy0" axis="0 1 0" type="hinge" range="-1e-6 1e-6" pos="0 0 0"/>
      <joint name="dummy1" axis="1 0 0" type="hinge" range="-1e-6 1e-6" pos="0 0 0"/>
      <inertial mass="0.629769" pos="-0.041018 -0.00014 0.049974"
        fullinertia="0.00315 0.00388 0.004285 8.2904e-7 0.00015 8.2299e-6"/>
      <geom mesh="link0_c" class="panda/collision"/>
      <body name="link1" pos="0 0 0.333">
        <inertial mass="4.970684" pos="0.003875 0.002081 -0.04762"
          fullinertia="0.70337 0.70661 0.0091170 -0.00013900 0.0067720 0.019169"/>
        <joint name="joint1"/>
        <geom mesh="link1_c" class="panda/collision"/>
        <body name="link2" quat="1 -1 0 0">
          <inertial mass="0.646926" pos="-0.003141 -0.02872 0.003495"
            fullinertia="0.0079620 2.8110e-2 2.5995e-2 -3.925e-3 1.0254e-2 7.04e-4"/>
          <joint name="joint2" range="-1.7628 1.7628"/>
          <geom mesh="link2_c" class="panda/collision"/>
          <body name="link3" pos="0 -0.316 0" quat="1 1 0 0">
            <joint name="joint3"/>
            <inertial mass="3.228604" pos="2.7518e-2 3.9252e-2 -6.6502e-2"
              fullinertia="3.7242e-2 3.6155e-2 1.083e-2 -4.761e-3 -1.1396e-2 -1.2805e-2"/>
            <geom mesh="link3_c" class="panda/collision"/>
            <body name="link4" pos="0.0825 0 0" quat="1 1 0 0">
              <inertial mass="3.587895" pos="-5.317e-2 1.04419e-1 2.7454e-2"
                fullinertia="2.5853e-2 1.9552e-2 2.8323e-2 7.796e-3 -1.332e-3 8.641e-3"/>
              <joint name="joint4" range="-3.0718 -0.0698"/>
              <geom mesh="link4_c" class="panda/collision"/>
              <body name="link5" pos="-0.0825 0.384 0" quat="1 -1 0 0">
                <inertial mass="1.225946" pos="-1.1953e-2 4.1065e-2 -3.8437e-2"
                  fullinertia="3.5549e-2 2.9474e-2 8.627e-3 -2.117e-3 -4.037e-3 2.29e-4"/>
                <joint name="joint5"/>
                <geom mesh="link5_c0" class="panda/collision"/>
                <geom mesh="link5_c1" class="panda/collision"/>
                <geom mesh="link5_c2" class="panda/collision"/>
                <body name="link6" quat="1 1 0 0">
                  <inertial mass="1.666555" pos="6.0149e-2 -1.4117e-2 -1.0517e-2"
                    fullinertia="1.964e-3 4.354e-3 5.433e-3 1.09e-4 -1.158e-3 3.41e-4"/>
                  <joint name="joint6" range="-0.0175 3.7525"/>
                  <geom mesh="link6_c" class="panda/collision"/>
                  <body name="link7" pos="0.088 0 0" quat="1 1 0 0">
                    <inertial mass="7.35522e-01" pos="1.0517e-2 -4.252e-3 6.1597e-2"
                      fullinertia="1.2516e-2 1.0027e-2 4.815e-3 -4.28e-4 -1.196e-3 -7.41e-4"/>
                    <joint name="joint7"/>
                    <geom mesh="link7_c" class="panda/collision"/>
                    <!-- whip begin -->
                    <geom type="sphere" pos="0 0 0.107" size="0.045" material="white"/>
                    <body name="B0" pos="0 0 0.152">
                      <joint class="X"/>
                      <joint class="Y"/>
                      <geom class="Z"/>
                      <body name="B1" pos="0 0 0.152">
                        <joint class="X"/>
                        <joint class="Y"/>
                        <geom class="Z"/>
                        <body name="B2" pos="0 0 0.152">
                          <joint class="X"/>
                          <joint class="Y"/>
                          <geom class="Z"/>
                          <body name="B3" pos="0 0 0.152">
                            <joint class="X"/>
                            <joint class="Y"/>
                            <geom class="Z"/>
                            <body name="B4" pos="0 0 0.152">
                              <joint class="X"/>
                              <joint class="Y"/>
                              <geom class="Z"/>
                            </body>
                          </body>
                        </body>
                      </body>
                    </body>  <!-- B0 -->
                    <!-- whip end -->
                  </body>
                </body>
              </body>
            </body>
          </body>
        </body>
      </body>
    </body>

    <body name="target" pos="-0.9 0 1.0">
      <geom type="sphere" size=".05 .05 .05" class="target" name="target"/>
      <freejoint name="target"/>
    </body>
  </worldbody>

  <actuator>
    <general class="panda" name="actuator1" joint="joint1" gainprm="4500" biasprm="0 -4500 -450"/>
    <general class="panda" name="actuator2" joint="joint2" gainprm="4500" biasprm="0 -4500 -450"
      ctrlrange="-1.7628 1.7628"/>
    <general class="panda" name="actuator3" joint="joint3" gainprm="3500" biasprm="0 -3500 -350"/>
    <general class="panda" name="actuator4" joint="joint4" gainprm="3500" biasprm="0 -3500 -350"
      ctrlrange="-3.0718 -0.0698"/>
    <general class="panda" name="actuator5" joint="joint5" gainprm="2000" biasprm="0 -2000 -200" forcerange="-12 12"/>
    <general class="panda" name="actuator6" joint="joint6" gainprm="2000" biasprm="0 -2000 -200" forcerange="-12 12"
      ctrlrange="-0.0175 3.7525"/>
    <general class="panda" name="actuator7" joint="joint7" gainprm="2000" biasprm="0 -2000 -200" forcerange="-12 12"/>
  </actuator>

</mujoco>
baixianger commented 10 months ago

If I increase the number of nested rope segments above 5, the generalized pipeline step returns NaN states.

Step result state.x.pos with 5 nested rope segments:

ID      : Name   Parent  Type    DoF
1       : link0  -1      2       [0, 1]
2       : link1  0       1       [2]
3       : link2  1       1       [3]
4       : link3  2       1       [4]
5       : link4  3       1       [5]
6       : link5  4       1       [6]
7       : link6  5       1       [7]
8       : link7  6       1       [8]
9       : B0     7       2       [9, 10]
10      : B1     8       2       [11, 12]
11      : B2     9       2       [13, 14]
12      : B3     10      2       [15, 16]
13      : B4     11      2       [17, 18]
14      : target         -1      f       [19, 20, 21, 22, 23, 24]
[[ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
 [-8.0064544e-04  6.3779644e-06  3.3299905e-01]
 [-8.0064544e-04  6.3779644e-06  3.3299905e-01]
 [-1.0635098e-03  1.2516281e-05  6.4899898e-01]
 [ 8.1436455e-02  6.2528095e-05  6.4906758e-01]
 [ 2.2547080e-01  1.5716164e-04  1.0144665e+00]
 [ 2.2547080e-01  1.5716164e-04  1.0144665e+00]
 [ 3.0516160e-01  3.3137485e-04  9.7713888e-01]
 [ 2.4068651e-01  3.2239131e-04  8.3949089e-01]
 [ 1.7497984e-01  9.3778474e-03  7.0272601e-01]
 [ 1.1845386e-01  4.1922983e-03  5.6172270e-01]
 [ 6.2959567e-02  6.3486844e-03  4.2023158e-01]
 [ 1.1395495e-02  4.7966195e-03  2.7725345e-01]
 [-8.9999998e-01  0.0000000e+00  9.3648243e-01]]

Step result state.x.pos with 6 nested rope segments:

ID      : Name   Parent  Type    DoF
1       : link0  -1      2       [0, 1]
2       : link1  0       1       [2]
3       : link2  1       1       [3]
4       : link3  2       1       [4]
5       : link4  3       1       [5]
6       : link5  4       1       [6]
7       : link6  5       1       [7]
8       : link7  6       1       [8]
9       : B0     7       2       [9, 10]
10      : B1     8       2       [11, 12]
11      : B2     9       2       [13, 14]
12      : B3     10      2       [15, 16]
13      : B4     11      2       [17, 18]
14      : B5     12      2       [19, 20]
15      : target         -1      f       [21, 22, 23, 24, 25, 26]
[[nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]]

[*]: The first link (link0) carries two dummy hinge joints (2 DoF) as a workaround for the link/joint problem: in Brax, the number of links should equal the number of joints.
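Since the failure shows up as the rope gets longer, it may help to generate the B0 to B{n-1} chain programmatically and sweep n instead of hand-editing the nesting. Below is a minimal stdlib-only sketch, assuming the X/Y/Z default classes and the 0.152 spacing of the whip above; make_rope and nesting_depth are hypothetical helpers, not part of Brax.

```python
import xml.etree.ElementTree as ET


def make_rope(n_segments, spacing=0.152, prefix="B"):
    """Build n_segments nested <body> elements, each holding two hinge joints
    (classes X and Y) and one capsule geom (class Z), like the whip above.
    Returns the outermost <body> element."""
    if n_segments < 1:
        raise ValueError("need at least one segment")
    root = None
    parent = None
    for i in range(n_segments):
        body = ET.Element("body", name=f"{prefix}{i}", pos=f"0 0 {spacing}")
        ET.SubElement(body, "joint", {"class": "X"})
        ET.SubElement(body, "joint", {"class": "Y"})
        ET.SubElement(body, "geom", {"class": "Z"})
        if parent is None:
            root = body          # first segment becomes the returned root
        else:
            parent.append(body)  # nest each new segment inside the previous one
        parent = body
    return root


def nesting_depth(body):
    """Count the <body> levels from `body` down its first-child chain."""
    depth = 0
    while body is not None:
        depth += 1
        body = body.find("body")  # direct child <body>, or None at the tip
    return depth


rope = make_rope(6)
print(nesting_depth(rope))  # 6
```

To splice a rope into the Panda model, one could parse the model, locate link7 with root.find(".//body[@name='link7']"), append make_rope(n), and pass ET.tostring(root, encoding="unicode") to brax.io.mjcf.loads (the mesh assets still have to resolve).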

baixianger commented 10 months ago

I ran another experiment with nested bodies, this time a plain chain of 16 nested bodies, each with one hinge joint and a sphere geom.

MJCF file:

<mujoco>
  <option timestep="0.01"/>
  <worldbody>
    <body pos="0 0 -0.5">
      <joint name="joint1" axis="0 1 0" type="hinge"/>
      <geom pos="0 0 -1.5" size=".1" type="sphere"/>
      <body pos="0 0 -1.5">
        <joint name="joint2" axis="0 1 0" type="hinge"/>
        <geom pos="0 0 -1.5" size=".1" type="sphere"/>
        <body pos="0 0 -1.5">
          <joint axis="0 1 0" type="hinge"/>
          <geom pos="0 0 -1.5" size=".1" type="sphere"/>
          <body pos="0 0 -1.5">
            <joint axis="0 1 0" type="hinge"/>
            <geom pos="0 0 -1.5" size=".1" type="sphere"/>
            <body pos="0 0 -1.5">
              <joint axis="0 1 0" type="hinge"/>
              <geom pos="0 0 -1.5" size=".1" type="sphere"/>
              <body pos="0 0 -1.5">
                <joint axis="0 1 0" type="hinge"/>
                <geom pos="0 0 -1.5" size=".1" type="sphere"/>
                <body pos="0 0 -1.5">
                  <joint axis="0 1 0" type="hinge"/>
                  <geom pos="0 0 -1.5" size=".1" type="sphere"/>
                  <body pos="0 0 -1.5">
                    <joint axis="0 1 0" type="hinge"/>
                    <geom pos="0 0 -1.5" size=".1" type="sphere"/>
                    <body pos="0 0 -1.5">
                      <joint axis="0 1 0" type="hinge"/>
                      <geom pos="0 0 -1.5" size=".1" type="sphere"/>
                      <body pos="0 0 -1.5">
                        <joint axis="0 1 0" type="hinge"/>
                        <geom pos="0 0 -1.5" size=".1" type="sphere"/>
                        <body pos="0 0 -1.5">
                          <joint axis="0 1 0" type="hinge"/>
                          <geom pos="0 0 -1.5" size=".1" type="sphere"/>
                          <body pos="0 0 -1.5">
                            <joint axis="0 1 0" type="hinge"/>
                            <geom pos="0 0 -1.5" size=".1" type="sphere"/>
                            <body pos="0 0 -1.5">
                              <joint axis="0 1 0" type="hinge"/>
                              <geom pos="0 0 -1.5" size=".1" type="sphere"/>
                              <body pos="0 0 -1.5">
                                <joint axis="0 1 0" type="hinge"/>
                                <geom pos="0 0 -1.5" size=".1" type="sphere"/>
                                <body pos="0 0 -1.5">
                                  <joint axis="0 1 0" type="hinge"/>
                                  <geom pos="0 0 -1.5" size=".1" type="sphere"/>
                                  <body pos="0 0 -1.5">
                                    <joint axis="0 1 0" type="hinge"/>
                                    <geom pos="0 0 -1.5" size=".1" type="sphere"/>
                                  </body>
                                </body>
                              </body>
                            </body>
                          </body>
                        </body>
                      </body>
                    </body>
                  </body>
                </body>
              </body>
            </body>
          </body>
        </body>
      </body>
    </body>
  </worldbody>
  <actuator>
    <position joint="joint1" kp="500"/>
    <position joint="joint2" kp="500"/>
  </actuator>
</mujoco>

Code:

import os
import sys

# Put the parent of this script's directory on the import path (e.g. a local Brax checkout).
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import jax
from jax import numpy as jp
from brax.io import mjcf
from brax.generalized import pipeline

print(jax.default_backend())  # e.g. 'cpu' or 'gpu'

path = 'test/_nested_bodies.xml'
brax_sys = mjcf.load(path)  # named brax_sys to avoid shadowing the `sys` module
sys_info = zip(brax_sys.link_names, brax_sys.link_parents,
               brax_sys.link_types, brax_sys.dof_ranges())

print('ID\t: Name\t Parent\t Type\t DoF')
for i, (name, parent, link_type, dof) in enumerate(sys_info):
    print(f'{i+1}\t: {name}\t {parent}\t {link_type}\t {dof}')

# Zero (float) controls, one entry per actuator.
act = jp.zeros(brax_sys.act_size())
q = brax_sys.init_q
qd = jp.zeros(brax_sys.qd_size())

# Wrap once with jit and reuse the compiled functions in the loop.
init_fn = jax.jit(pipeline.init)
step_fn = jax.jit(pipeline.step)
state = init_fn(brax_sys, q, qd)

xpos = []
for i in range(100):
    state = step_fn(brax_sys, state, act)
    xpos.append(state.x.pos)

print(xpos[-1])

Output:

ID      : Name   Parent  Type    DoF
1       :        -1      1       [0]
2       :        0       1       [1]
3       :        1       1       [2]
4       :        2       1       [3]
5       :        3       1       [4]
6       :        4       1       [5]
7       :        5       1       [6]
8       :        6       1       [7]
9       :        7       1       [8]
10      :        8       1       [9]
11      :        9       1       [10]
12      :        10      1       [11]
13      :        11      1       [12]
14      :        12      1       [13]
15      :        13      1       [14]
16      :        14      1       [15]
[[nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]

The result is the same: the returned state is full of NaNs.
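Because only the final state is printed, it can be useful to find the first step at which the positions become non-finite, and which links are affected at that step. A small sketch, assuming xpos is the list of per-step state.x.pos arrays collected in the loop above (JAX arrays convert with np.asarray); first_nonfinite is a hypothetical helper.

```python
import numpy as np


def first_nonfinite(xpos):
    """Return (step, link_indices) for the first step whose positions contain
    NaN or inf, or None if every step is finite.

    xpos: sequence of arrays of shape (num_links, 3), one per simulation step.
    """
    for step, pos in enumerate(xpos):
        # True for each link (row) with at least one non-finite coordinate.
        bad_rows = ~np.isfinite(np.asarray(pos)).all(axis=1)
        if bad_rows.any():
            return step, np.flatnonzero(bad_rows).tolist()
    return None


# Toy data: link 2 blows up at step 1, everything is NaN by step 2.
toy = [
    np.zeros((3, 3)),
    np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [np.nan, 0.0, 0.0]]),
    np.full((3, 3), np.nan),
]
print(first_nonfinite(toy))  # (1, [2])
```

In the six-segment Panda run above every row is NaN, even the free-floating target that is not attached to the arm, so the step index is probably the more informative part of the result.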

btaba commented 8 months ago

Hi @baixianger, thanks for the bug report! It looks like some part of the generalized backend is unstable with this nested rope. The JAX docs at https://jax.readthedocs.io/en/latest/debugging/flags.html#jax-debugging-flags describe debugging flags for NaNs:

from jax import config
config.update("jax_debug_nans", True)

If you run with that, you may get an indication about where the NaNs are coming from, which would give hints about what to change in the model to make the calculation more stable (perhaps mass/inertia/timestep).
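For the timestep knob, one option is to rewrite the option element and rerun the same test at several values; smaller timesteps generally improve integrator stability at the cost of more steps per simulated second. For reference, the second test model sets timestep="0.01", five times MuJoCo's default of 0.002. A stdlib-only sketch; with_timestep is a hypothetical helper, and the returned string is assumed to be loaded with brax.io.mjcf.loads as before.

```python
import xml.etree.ElementTree as ET


def with_timestep(mjcf_xml, timestep):
    """Return a copy of an MJCF string whose <option timestep> is `timestep`.
    Creates the <option> element if the model has none."""
    root = ET.fromstring(mjcf_xml)
    option = root.find("option")
    if option is None:
        option = ET.Element("option")
        root.insert(0, option)  # add as the first child of <mujoco>
    option.set("timestep", str(timestep))
    return ET.tostring(root, encoding="unicode")


base = '<mujoco><option timestep="0.01"/><worldbody/></mujoco>'
for dt in (0.01, 0.005, 0.002):
    xml = with_timestep(base, dt)
    print(dt, ET.fromstring(xml).find("option").get("timestep"))
```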

I would also suggest turning off self-collisions on the Panda arm unless absolutely necessary; convex-convex collisions are expensive to compute in Brax.
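In the Panda model above, all arm collision meshes use the panda/collision default class, so one way to act on this is to give that class contype="0" conaffinity="0", the same convention panda/visual already uses. A minimal stdlib sketch, assuming Brax filters contacts by contype/conaffinity the way MuJoCo does; disable_class_collisions is a hypothetical helper. This removes all contacts for those geoms (floor and target included), not only arm self-collisions; geoms using other classes, such as the whip's Z capsules, are unaffected.

```python
import xml.etree.ElementTree as ET


def disable_class_collisions(mjcf_xml, class_name):
    """Set contype="0" conaffinity="0" on the <geom> default of a default class.

    Geoms that use `class_name` and do not override these attributes then
    stop generating contacts.
    """
    root = ET.fromstring(mjcf_xml)
    found = False
    for default in root.iter("default"):
        if default.get("class") == class_name:
            geom = default.find("geom")
            if geom is None:
                geom = ET.SubElement(default, "geom")
            geom.set("contype", "0")
            geom.set("conaffinity", "0")
            found = True
    if not found:
        raise KeyError(f"no <default class={class_name!r}> in model")
    return ET.tostring(root, encoding="unicode")


snippet = """<mujoco><default><default class="panda">
  <default class="panda/collision"><geom type="mesh" group="3"/></default>
</default></default><worldbody/></mujoco>"""
patched = disable_class_collisions(snippet, "panda/collision")
print(ET.fromstring(patched).find(".//default[@class='panda/collision']/geom").attrib)
```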

baixianger commented 8 months ago

Some gossip beyond this topic: MuJoCo 3 was released today, and I realized why your team replied so late over the past months. I guess you were all working on MuJoCo 3. LOL.

I really appreciate your work on both Brax and MuJoCo 3.

btaba commented 8 months ago

Haha @baixianger , yeah sorry about the late replies and thanks for bearing with us...

btaba commented 8 months ago

@baixianger OOC have you tried this with MJX?