Open baixianger opened 10 months ago
If I increase the nesting depth of the rope beyond 5 links, the generalized pipeline step returns NaN
states.
ID : Name Parent Type DoF
1 : link0 -1 2 [0, 1]
2 : link1 0 1 [2]
3 : link2 1 1 [3]
4 : link3 2 1 [4]
5 : link4 3 1 [5]
6 : link5 4 1 [6]
7 : link6 5 1 [7]
8 : link7 6 1 [8]
9 : B0 7 2 [9, 10]
10 : B1 8 2 [11, 12]
11 : B2 9 2 [13, 14]
12 : B3 10 2 [15, 16]
13 : B4 11 2 [17, 18]
14 : target -1 f [19, 20, 21, 22, 23, 24]
[[ 0.0000000e+00 0.0000000e+00 0.0000000e+00]
[-8.0064544e-04 6.3779644e-06 3.3299905e-01]
[-8.0064544e-04 6.3779644e-06 3.3299905e-01]
[-1.0635098e-03 1.2516281e-05 6.4899898e-01]
[ 8.1436455e-02 6.2528095e-05 6.4906758e-01]
[ 2.2547080e-01 1.5716164e-04 1.0144665e+00]
[ 2.2547080e-01 1.5716164e-04 1.0144665e+00]
[ 3.0516160e-01 3.3137485e-04 9.7713888e-01]
[ 2.4068651e-01 3.2239131e-04 8.3949089e-01]
[ 1.7497984e-01 9.3778474e-03 7.0272601e-01]
[ 1.1845386e-01 4.1922983e-03 5.6172270e-01]
[ 6.2959567e-02 6.3486844e-03 4.2023158e-01]
[ 1.1395495e-02 4.7966195e-03 2.7725345e-01]
[-8.9999998e-01 0.0000000e+00 9.3648243e-01]]
ID : Name Parent Type DoF
1 : link0 -1 2 [0, 1]
2 : link1 0 1 [2]
3 : link2 1 1 [3]
4 : link3 2 1 [4]
5 : link4 3 1 [5]
6 : link5 4 1 [6]
7 : link6 5 1 [7]
8 : link7 6 1 [8]
9 : B0 7 2 [9, 10]
10 : B1 8 2 [11, 12]
11 : B2 9 2 [13, 14]
12 : B3 10 2 [15, 16]
13 : B4 11 2 [17, 18]
14 : B5 12 2 [19, 20]
15 : target -1 f [21, 22, 23, 24, 25, 26]
[[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]]
[*]: The first joint has 2 DoF to work around the link/joint mismatch: in Brax, the number of links must equal the number of joints.
<mujoco>
<option timestep="0.01"/>
<worldbody>
<body pos="0 0 -0.5">
<joint name="joint1" axis="0 1 0" type="hinge"/>
<geom pos="0 0 -1.5" size=".1" type="sphere"/>
<body pos="0 0 -1.5">
<joint name="joint2" axis="0 1 0" type="hinge"/>
<geom pos="0 0 -1.5" size=".1" type="sphere"/>
<body pos="0 0 -1.5">
<joint axis="0 1 0" type="hinge"/>
<geom pos="0 0 -1.5" size=".1" type="sphere"/>
<body pos="0 0 -1.5">
<joint axis="0 1 0" type="hinge"/>
<geom pos="0 0 -1.5" size=".1" type="sphere"/>
<body pos="0 0 -1.5">
<joint axis="0 1 0" type="hinge"/>
<geom pos="0 0 -1.5" size=".1" type="sphere"/>
<body pos="0 0 -1.5">
<joint axis="0 1 0" type="hinge"/>
<geom pos="0 0 -1.5" size=".1" type="sphere"/>
<body pos="0 0 -1.5">
<joint axis="0 1 0" type="hinge"/>
<geom pos="0 0 -1.5" size=".1" type="sphere"/>
<body pos="0 0 -1.5">
<joint axis="0 1 0" type="hinge"/>
<geom pos="0 0 -1.5" size=".1" type="sphere"/>
<body pos="0 0 -1.5">
<joint axis="0 1 0" type="hinge"/>
<geom pos="0 0 -1.5" size=".1" type="sphere"/>
<body pos="0 0 -1.5">
<joint axis="0 1 0" type="hinge"/>
<geom pos="0 0 -1.5" size=".1" type="sphere"/>
<body pos="0 0 -1.5">
<joint axis="0 1 0" type="hinge"/>
<geom pos="0 0 -1.5" size=".1" type="sphere"/>
<body pos="0 0 -1.5">
<joint axis="0 1 0" type="hinge"/>
<geom pos="0 0 -1.5" size=".1" type="sphere"/>
<body pos="0 0 -1.5">
<joint axis="0 1 0" type="hinge"/>
<geom pos="0 0 -1.5" size=".1" type="sphere"/>
<body pos="0 0 -1.5">
<joint axis="0 1 0" type="hinge"/>
<geom pos="0 0 -1.5" size=".1" type="sphere"/>
<body pos="0 0 -1.5">
<joint axis="0 1 0" type="hinge"/>
<geom pos="0 0 -1.5" size=".1" type="sphere"/>
<body pos="0 0 -1.5">
<joint axis="0 1 0" type="hinge"/>
<geom pos="0 0 -1.5" size=".1" type="sphere"/>
</body>
</body>
</body>
</body>
</body>
</body>
</body>
</body>
</body>
</body>
</body>
</body>
</body>
</body>
</body>
</body>
</worldbody>
<actuator>
<position joint="joint1" kp="500"/>
<position joint="joint2" kp="500"/>
</actuator>
</mujoco>
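To check at which depth the NaNs first appear, it may help to generate this chain for an arbitrary number of links instead of editing the XML by hand. The following is a minimal sketch using only the Python standard library; the function name `nested_chain_mjcf` and its parameters are hypothetical and not part of Brax or MuJoCo. It mirrors the model above (hinge joints about the y axis, sphere geoms, position actuators on the first two joints) and exposes `timestep` so that a smaller step can be tried as well.

```python
import xml.etree.ElementTree as ET


def nested_chain_mjcf(n_links: int, timestep: float = 0.01,
                      link_length: float = 1.5) -> str:
    """Return MJCF text for a chain of `n_links` nested hinge-jointed bodies.

    Hypothetical helper: it reproduces the structure of the model above,
    it is not an API of Brax or MuJoCo.
    """
    if n_links < 1:
        raise ValueError("n_links must be at least 1")

    root = ET.Element("mujoco")
    ET.SubElement(root, "option", timestep=str(timestep))
    parent = ET.SubElement(root, "worldbody")
    offset = f"0 0 {-link_length}"

    for i in range(n_links):
        # The first body sits at z = -0.5; every later body hangs one
        # link length below its parent, as in the hand-written model.
        body = ET.SubElement(parent, "body", pos="0 0 -0.5" if i == 0 else offset)
        joint_attrs = {"axis": "0 1 0", "type": "hinge"}
        if i < 2:
            # Only the first two joints are named, because only they are actuated.
            joint_attrs["name"] = f"joint{i + 1}"
        ET.SubElement(body, "joint", joint_attrs)
        ET.SubElement(body, "geom", pos=offset, size=".1", type="sphere")
        parent = body  # nest the next body inside this one

    actuator = ET.SubElement(root, "actuator")
    for i in range(min(n_links, 2)):
        ET.SubElement(actuator, "position", joint=f"joint{i + 1}", kp="500")

    if hasattr(ET, "indent"):  # pretty-printing is available from Python 3.9
        ET.indent(root)
    return ET.tostring(root, encoding="unicode")
```

Writing the returned string to a file and loading it with `mjcf.load(path)`, as the script below does, should then allow sweeping `n_links` and `timestep` to see where the generalized pipeline starts to diverge.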
import os
import sys
# Make the parent directory of this script's folder importable.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import jax
from jax import numpy as jp
from brax.io import mjcf
from brax.generalized import pipeline

# Report which backend JAX is running on.
print(jax.default_backend())

path = 'test/_nested_bodies.xml'
# Note: from here on `sys` is the Brax system and shadows the `sys` module.
sys = mjcf.load(path)

# Print one row per link: name, parent index, joint type, and DoF indices.
sys_info = zip(sys.link_names, sys.link_parents, sys.link_types, sys.dof_ranges())
print('ID\t: Name\t Parent\t Type\t DoF')
for i, (name, parent, link_type, dof) in enumerate(sys_info):
    print(f'{i+1}\t: {name}\t {parent}\t {link_type}\t {dof}')

act = jp.array([0.0, 0.0])  # zero targets for the two position actuators
q = sys.init_q
qd = jp.zeros(sys.qd_size())
state = jax.jit(pipeline.init)(sys, q, qd)

# Wrap the step function in jit once, outside the loop.
step_fn = jax.jit(pipeline.step)
xpos = []
for _ in range(100):
    state = step_fn(sys, state, act)
    xpos.append(state.x.pos)
print(xpos[-1])
ID : Name Parent Type DoF
1 : -1 1 [0]
2 : 0 1 [1]
3 : 1 1 [2]
4 : 2 1 [3]
5 : 3 1 [4]
6 : 4 1 [5]
7 : 5 1 [6]
8 : 6 1 [7]
9 : 7 1 [8]
10 : 8 1 [9]
11 : 9 1 [10]
12 : 10 1 [11]
13 : 11 1 [12]
14 : 12 1 [13]
15 : 13 1 [14]
16 : 14 1 [15]
[[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
This gives the same result: the returned state is full of NaNs.
Hi @baixianger , thanks for the bug report! It looks like part of the generalized backend is unstable with this nested rope. JAX has debugging flags for NaNs, documented at https://jax.readthedocs.io/en/latest/debugging/flags.html#jax-debugging-flags :
from jax import config
config.update("jax_debug_nans", True)
If you run with that, you may get an indication about where the NaNs are coming from, which would give hints about what to change in the model to make the calculation more stable (perhaps mass/inertia/timestep).
I would also suggest turning off self-collisions with the Panda arm unless absolutely necessary; convex<>convex collisions are expensive to compute in Brax.
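As a complement to the NaN debugging flag suggested above, it can be useful to know at which simulation step the state first stops being finite: a failure at step 0 points at the model or the initial state, while a failure after many steps looks more like a gradual divergence that a smaller timestep might tame. The following is a minimal sketch; `first_nonfinite_step` is a hypothetical helper, not a Brax or JAX function, and it assumes each recorded step has been converted to nested Python lists (for example with `.tolist()` on each entry of `xpos`).

```python
import math
from typing import Iterable, Optional, Sequence


def first_nonfinite_step(
    history: Iterable[Sequence[Sequence[float]]],
) -> Optional[int]:
    """Return the index of the first step containing NaN or inf, else None.

    Hypothetical helper: `history` holds one entry per simulation step, and
    each entry is a list of [x, y, z] rows such as a link position array.
    """
    for step, positions in enumerate(history):
        # math.isfinite is False for both NaN and +/-inf.
        if any(not math.isfinite(value) for row in positions for value in row):
            return step
    return None
```

With the reproduction script above, a call such as `first_nonfinite_step(p.tolist() for p in xpos)` would report the first bad step, assuming `xpos` holds the per-step `state.x.pos` arrays.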
A bit off topic: MuJoCo 3 was released today. Now I understand why your team was slow to reply over these past months. I guess you were all working on MuJoCo 3. LOL.
I really appreciate your work on both Brax and MuJoCo 3.
Haha @baixianger , yeah sorry about the late replies and thanks for bearing with us...
@baixianger OOC have you tried this with MJX?
I implemented my rope with nested rigid bodies and joints and ran some experiments in Brax. I found that Brax only supports a limited depth of nested bodies.
Would you happen to have any solution for a rope?