google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0
2.14k stars 234 forks source link

NaN encountered in pipeline_step() #474

Closed MasterXiong closed 2 months ago

MasterXiong commented 2 months ago

Hi,

I encountered an issue similar to #467 when running a simulation of a robot with random actions. I followed the suggestions mentioned in that thread but still observed NaN values in the simulation. Below are the minimal code and XML file needed to reproduce the issue. Could you help check what the problem might be? Thanks a lot!

from brax import base
from brax import math
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf
import jax
from jax import numpy as jp
import mujoco

class Unimal(PipelineEnv):
  """Minimal Brax environment that steps an MJCF model with raw actions.

  The environment produces no observation (``obs`` is always ``None``) and a
  constant zero reward / done flag; it exists only to drive the physics
  pipeline.

  Incoming actions are laid out with two slots per link (``2 * link + k``);
  ``step`` gathers from them the entries selected by ``self.action_index``
  before handing them to the pipeline.
  """

  def __init__(
      self,
      xml_path,
      backend='generalized',
      **kwargs,
  ):
    """Loads the model at ``xml_path`` and configures the pipeline.

    Args:
      xml_path: path of the MJCF file passed to ``mjcf.load``.
      backend: name of the Brax physics backend forwarded to ``PipelineEnv``.
      **kwargs: extra arguments forwarded to ``PipelineEnv``; ``n_frames``
        defaults to 5 when not supplied.
    """
    sys = mjcf.load(xml_path)
    sys = sys.replace(dt=0.005)

    # Only fill in n_frames when the caller did not choose a value.
    kwargs.setdefault('n_frames', 5)

    super().__init__(sys=sys, backend=backend, **kwargs)

    self.get_action_index()
    self._reset_noise_scale = 0.1

  def get_action_index(self):
    """Computes ``self.limb_num`` and the gather index ``self.action_index``.

    Every DoF after the first six is mapped to slot ``2 * link + r``, where
    ``r`` is 1 when the DoF shares its link with the preceding DoF and 0
    otherwise.
    """
    self.limb_num = self.sys.num_links()
    # NOTE(review): the [6:] slice assumes the first six DoFs belong to a root
    # free joint that takes no action -- confirm for the model being loaded.
    dof_link_idx = self.sys.dof_link()[6:].copy()
    # True where a DoF sits on the same link as the DoF right before it.
    repeat_mask = (dof_link_idx[1:] == dof_link_idx[:-1])
    # The first DoF has no predecessor, so it is never a repeat.
    repeat_mask = jp.insert(repeat_mask, 0, 0)
    # NOTE(review): the mask is only 0/1, so a link with three or more DoFs
    # would map several DoFs to the same slot -- assumes at most two per link.
    self.action_index = dof_link_idx * 2 + repeat_mask

  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state.

    Args:
      rng: JAX PRNG key used to sample the initial position/velocity noise.

    Returns:
      A ``State`` whose pipeline state starts from ``qpos0`` plus uniform
      noise, with ``obs=None``, zero reward/done and empty metrics.
    """
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(rng2, (self.sys.nv,), minval=low, maxval=hi)

    data = self.pipeline_init(qpos, qvel)
    obs = None

    reward, done = jp.zeros(2)
    metrics = {}
    return State(data, obs, reward, done, metrics)

  def step(self, state: State, action: jp.ndarray) -> State:
    """Run one timestep of the environment's dynamics.

    Args:
      state: the current environment state.
      action: action vector indexed by ``self.action_index`` (two slots per
        link).

    Returns:
      ``state`` with its pipeline state advanced and ``obs``/``reward``/
      ``done`` reset to ``None``/0/0.
    """
    # Keep only the action entries selected by the precomputed gather index.
    action = action[self.action_index]

    pipeline_state0 = state.pipeline_state
    pipeline_state = self.pipeline_step(pipeline_state0, action)

    obs = None
    # Use JAX scalars (as reset does) so the State leaves keep the same type
    # from one step to the next.
    reward, done = jp.zeros(2)

    return state.replace(
        pipeline_state=pipeline_state, obs=obs, reward=reward, done=done
    )

# Build the environment from the MJCF file and size the action vector with
# two slots per link, matching the indexing done inside Unimal.step.
xml_path = 'robot.xml'
agent = Unimal(xml_path)
action_dim = 2 * agent.sys.num_links()

jit_env_reset = jax.jit(agent.reset)
jit_env_step = jax.jit(agent.step)

# Pre-sample one standard-normal action per step for the whole episode.
episode_length = 2560
random_action = jax.random.normal(
    jax.random.PRNGKey(seed=1),
    shape=(episode_length, action_dim),
)

# Roll the episode out, printing q every step and stopping at the first NaN.
state = jit_env_reset(rng=jax.random.PRNGKey(seed=0))
for t in range(episode_length):
  state = jit_env_step(state, random_action[t])
  q = state.pipeline_state.q
  print(q)
  if jp.isnan(q).any():
    print(t)
    break

And the XML configuration is:

<?xml version='1.0' encoding='UTF-8'?>
<!-- Universal Animal Template: unimal -->
<mujoco model="unimal">
  <compiler angle="degree"/>
  <size njmax="2000" nconmax="500"/>
  <option timestep=".005">
    <flag filterparent="disable"/>
  </option>
  <!-- Common defaults to make search space tractable -->
  <default>
    <!-- Define motor defaults -->
    <motor ctrlrange="-1 1" ctrllimited="true"/>
    <!-- Define joint defaults -->
    <default class="normal_joint">
      <joint type="hinge" damping="1" stiffness="1" armature="1" limited="true" range="-120 120" solimplimit="0 0.99 0.01"/>
    </default>
    <default class="walker_joint">
      <joint type="hinge" damping="0.2" stiffness="1" armature=".01" limited="true" range="-120 120" solimplimit="0 0.99 0.01"/>
    </default>
    <default class="stiff_joint">
      <joint type="hinge" damping="5" stiffness="10" armature=".01" limited="true" solimplimit="0 0.99 0.01"/>
    </default>
    <default class="free">
      <joint limited="false" damping="0" armature="0" stiffness="0"/>
    </default>
    <default class="growth_site">
      <site size="1e-6 1e-6 1e-6"/>
    </default>
    <default class="torso_growth_site">
      <site size="1e-6 1e-6 1e-6"/>
    </default>
    <default class="mirror_growth_site">
      <site size="1e-6 1e-6 1e-6"/>
    </default>
    <default class="btm_pos_site">
      <site size="1e-6 1e-6 1e-6"/>
    </default>
    <default class="box_face_site">
      <site size="1e-6 1e-6 1e-6"/>
    </default>
    <default class="imu_vel">
      <site type="box" size="0.05" rgba="1 0 0 1"/>
    </default>
    <default class="touch_site">
      <site group="3" rgba="0 0 1 .3"/>
    </default>
    <default class="food_site">
      <site material="food" size="0.15"/>
    </default>
    <!-- Define geom defaults -->
    <geom type="capsule" condim="3" friction="0.7 0.1 0.1" material="self"/>
  </default>
  <worldbody>
    <light diffuse="1 1 1" directional="true" exponent="1" pos="0 0 1" specular=".1 .1 .1"/>
    <!-- <geom name="floor" type="plane" pos="0 0 0" size="50 50 1" material="grid"/> -->
    <!-- Programmatically generated xml goes here -->
    <body name="torso/0" pos="0 0 0.75">
      <joint name="root" type="free" class="free"/>
      <site name="root" class="imu_vel"/>
      <geom name="torso/0" type="sphere" size="0.1" condim="3" density="1000"/>
      <camera name="side" pos="0 -7 2" xyaxes="1 0 0 0 1 2" mode="trackcom"/>
      <site name="torso/0" class="growth_site" pos="0 0 0"/>
      <site name="torso/btm_pos/0" class="btm_pos_site" pos="0 0 -0.1"/>
      <site name="torso/touch/0" class="touch_site" size="0.11"/>
      <site name="torso/horizontal_y/0" class="torso_growth_site" pos="-0.1 0 0"/>
      <body name="limb/0" pos="0.0 0.0 -0.1">
        <joint name="limbx/0" type="hinge" class="normal_joint" range="0 60" pos="0.0 0.0 0.05" axis="1.0 0.0 0.0"/>
        <joint name="limby/0" type="hinge" class="normal_joint" range="-60 30" pos="0.0 0.0 0.05" axis="0.0 1.0 0.0"/>
        <geom name="limb/0" type="capsule" fromto="0.0 0.0 0.0 0.0 0.0 -0.45" size="0.05" density="600"/>
        <site name="limb/mid/0" class="growth_site" pos="0.0 0.0 -0.25"/>
        <site name="limb/btm/0" class="growth_site" pos="0.0 0.0 -0.45"/>
        <site name="limb/btm_pos/0" class="btm_pos_site" pos="0.0 0.0 -0.45"/>
        <site name="limb/touch/0" class="touch_site" size="0.060000000000000005" fromto="0.0 0.0 0.0 0.0 0.0 -0.45" type="capsule"/>
      </body>
      <body name="limb/4" pos="-0.1 0.0 0.0">
        <joint name="limby/4" type="hinge" class="normal_joint" range="-30 60" pos="0.05 0.0 0.0" axis="0.0 1.0 0.0"/>
        <geom name="limb/4" type="capsule" fromto="0.0 0.0 0.0 -0.45 0.0 0.0" size="0.05" density="600"/>
        <site name="limb/mid/4" class="growth_site" pos="-0.25 0.0 0.0"/>
        <site name="limb/btm/4" class="growth_site" pos="-0.45 0.0 0.0"/>
        <site name="limb/btm_pos/4" class="btm_pos_site" pos="-0.5 0.0 0.0"/>
        <site name="limb/touch/4" class="touch_site" size="0.060000000000000005" fromto="0.0 0.0 0.0 -0.45 0.0 0.0" type="capsule"/>
        <body name="limb/5" pos="-0.45 0.05 0.0">
          <joint name="limbx/5" type="hinge" class="normal_joint" range="-60 30" pos="0.0 -0.05 0.0" axis="1.0 0.0 0.0"/>
          <geom name="limb/5" type="capsule" fromto="0.0 0.0 0.0 0.0 0.45 0.0" size="0.05" density="600"/>
          <site name="limb/mid/5" class="mirror_growth_site" pos="0.0 0.25 0.0"/>
          <site name="limb/btm/5" class="mirror_growth_site" pos="0.0 0.45 0.0"/>
          <site name="limb/btm_pos/5" class="btm_pos_site" pos="0.0 0.5 0.0"/>
          <site name="limb/touch/5" class="touch_site" size="0.060000000000000005" fromto="0.0 0.0 0.0 0.0 0.45 0.0" type="capsule"/>
          <body name="limb/7" pos="-0.05 0.45 0.0">
            <joint name="limbx/7" type="hinge" class="normal_joint" range="-60 30" pos="0.05 0.0 0.0" axis="0.0 0.0 -1.0"/>
            <joint name="limby/7" type="hinge" class="normal_joint" range="-30 60" pos="0.05 0.0 0.0" axis="0.0 1.0 0.0"/>
            <geom name="limb/7" type="capsule" fromto="0.0 0.0 0.0 -0.25 0.0 0.0" size="0.05" density="600"/>
            <site name="limb/mid/7" class="mirror_growth_site" pos="-0.15 0.0 0.0"/>
            <site name="limb/btm/7" class="mirror_growth_site" pos="-0.25 0.0 0.0"/>
            <site name="limb/btm_pos/7" class="btm_pos_site" pos="-0.3 0.0 0.0"/>
            <site name="limb/touch/7" class="touch_site" size="0.060000000000000005" fromto="0.0 0.0 0.0 -0.25 0.0 0.0" type="capsule"/>
            <body name="limb/9" pos="-0.3 0.0 0.0">
              <joint name="limby/9" type="hinge" class="normal_joint" range="-30 30" pos="0.05 0.0 0.0" axis="0.0 1.0 0.0"/>
              <geom name="limb/9" type="capsule" fromto="0.0 0.0 0.0 -0.45 0.0 0.0" size="0.05" density="600"/>
              <site name="limb/mid/9" class="mirror_growth_site" pos="-0.25 0.0 0.0"/>
              <site name="limb/btm/9" class="mirror_growth_site" pos="-0.45 0.0 0.0"/>
              <site name="limb/btm_pos/9" class="btm_pos_site" pos="-0.5 0.0 0.0"/>
              <site name="limb/touch/9" class="touch_site" size="0.060000000000000005" fromto="0.0 0.0 0.0 -0.45 0.0 0.0" type="capsule"/>
              <body name="limb/12" pos="-0.49 0.0 -0.04">
                <joint name="limby/12" type="hinge" class="normal_joint" range="-90 0" pos="0.04 0.0 0.04" axis="0.0 1.0 0.0"/>
                <geom name="limb/12" type="capsule" fromto="0.0 0.0 0.0 -0.25 0.0 -0.25" size="0.05" density="600"/>
                <site name="limb/mid/12" class="mirror_growth_site" pos="-0.14 0.0 -0.14"/>
                <site name="limb/btm/12" class="mirror_growth_site" pos="-0.25 0.0 -0.25"/>
                <site name="limb/btm_pos/12" class="btm_pos_site" pos="-0.28 0.0 -0.28"/>
                <site name="limb/touch/12" class="touch_site" size="0.060000000000000005" fromto="0.0 0.0 0.0 -0.25 0.0 -0.25" type="capsule"/>
              </body>
            </body>
          </body>
        </body>
        <body name="limb/6" pos="-0.45 -0.05 0.0">
          <joint name="limbx/6" type="hinge" class="normal_joint" range="-60 30" pos="0.0 0.05 0.0" axis="1.0 0.0 -0.0"/>
          <geom name="limb/6" type="capsule" fromto="0.0 0.0 0.0 0.0 -0.45 0.0" size="0.05" density="600"/>
          <site name="limb/mid/6" class="mirror_growth_site" pos="0.0 -0.25 0.0"/>
          <site name="limb/btm/6" class="mirror_growth_site" pos="0.0 -0.45 0.0"/>
          <site name="limb/btm_pos/6" class="btm_pos_site" pos="0.0 -0.5 0.0"/>
          <site name="limb/touch/6" class="touch_site" size="0.060000000000000005" fromto="0.0 0.0 0.0 0.0 -0.45 0.0" type="capsule"/>
          <body name="limb/8" pos="-0.05 -0.45 0.0">
            <joint name="limbx/8" type="hinge" class="normal_joint" range="-60 30" pos="0.05 0.0 0.0" axis="0.0 0.0 -1.0"/>
            <joint name="limby/8" type="hinge" class="normal_joint" range="-30 60" pos="0.05 0.0 0.0" axis="0.0 1.0 0.0"/>
            <geom name="limb/8" type="capsule" fromto="0.0 0.0 0.0 -0.25 0.0 0.0" size="0.05" density="600"/>
            <site name="limb/mid/8" class="mirror_growth_site" pos="-0.15 0.0 0.0"/>
            <site name="limb/btm/8" class="mirror_growth_site" pos="-0.25 0.0 0.0"/>
            <site name="limb/btm_pos/8" class="btm_pos_site" pos="-0.3 0.0 0.0"/>
            <site name="limb/touch/8" class="touch_site" size="0.060000000000000005" fromto="0.0 0.0 0.0 -0.25 0.0 0.0" type="capsule"/>
            <body name="limb/10" pos="-0.3 0.0 0.0">
              <joint name="limby/10" type="hinge" class="normal_joint" range="-30 30" pos="0.05 0.0 0.0" axis="0.0 1.0 0.0"/>
              <geom name="limb/10" type="capsule" fromto="0.0 0.0 0.0 -0.45 0.0 0.0" size="0.05" density="600"/>
              <site name="limb/mid/10" class="mirror_growth_site" pos="-0.25 0.0 0.0"/>
              <site name="limb/btm/10" class="mirror_growth_site" pos="-0.45 0.0 0.0"/>
              <site name="limb/btm_pos/10" class="btm_pos_site" pos="-0.5 0.0 0.0"/>
              <site name="limb/touch/10" class="touch_site" size="0.060000000000000005" fromto="0.0 0.0 0.0 -0.45 0.0 0.0" type="capsule"/>
              <body name="limb/11" pos="-0.49 0.0 -0.04">
                <joint name="limby/11" type="hinge" class="normal_joint" range="-90 0" pos="0.04 0.0 0.04" axis="0.0 1.0 0.0"/>
                <geom name="limb/11" type="capsule" fromto="0.0 0.0 0.0 -0.25 0.0 -0.25" size="0.05" density="600"/>
                <site name="limb/mid/11" class="mirror_growth_site" pos="-0.14 0.0 -0.14"/>
                <site name="limb/btm/11" class="mirror_growth_site" pos="-0.25 0.0 -0.25"/>
                <site name="limb/btm_pos/11" class="btm_pos_site" pos="-0.28 0.0 -0.28"/>
                <site name="limb/touch/11" class="touch_site" size="0.060000000000000005" fromto="0.0 0.0 0.0 -0.25 0.0 -0.25" type="capsule"/>
              </body>
            </body>
          </body>
        </body>
      </body>
    </body>
  </worldbody>
  <actuator>
    <motor joint="limbx/0" gear="200" name="limbx/0"/>
    <motor joint="limby/0" gear="300" name="limby/0"/>
    <motor joint="limby/4" gear="150" name="limby/4"/>
    <motor joint="limbx/5" gear="300" name="limbx/5"/>
    <motor joint="limbx/7" gear="250" name="limbx/7"/>
    <motor joint="limby/7" gear="250" name="limby/7"/>
    <motor joint="limby/9" gear="150" name="limby/9"/>
    <motor joint="limby/12" gear="150" name="limby/12"/>
    <motor joint="limbx/6" gear="300" name="limbx/6"/>
    <motor joint="limbx/8" gear="250" name="limbx/8"/>
    <motor joint="limby/8" gear="250" name="limby/8"/>
    <motor joint="limby/10" gear="150" name="limby/10"/>
    <motor joint="limby/11" gear="150" name="limby/11"/>
  </actuator>
  <sensor>
    <accelerometer name="torso_accel" site="root"/>
    <gyro name="torso_gyro" site="root"/>
    <velocimeter name="torso_vel" site="root"/>
    <subtreeangmom name="unimal_am" body="torso/0"/>
    <touch name="torso/0" site="torso/touch/0"/>
    <touch name="limb/0" site="limb/touch/0"/>
    <touch name="limb/4" site="limb/touch/4"/>
    <touch name="limb/5" site="limb/touch/5"/>
    <touch name="limb/7" site="limb/touch/7"/>
    <touch name="limb/9" site="limb/touch/9"/>
    <touch name="limb/12" site="limb/touch/12"/>
    <touch name="limb/6" site="limb/touch/6"/>
    <touch name="limb/8" site="limb/touch/8"/>
    <touch name="limb/10" site="limb/touch/10"/>
    <touch name="limb/11" site="limb/touch/11"/>
  </sensor>
  <!-- Add hfield assets -->
  <asset/>
  <!-- List of contacts to exclude -->
  <contact/>
  <!-- Define material, texture etc -->
  <asset>
    <material name="grid" texture="grid" texrepeat="1 1" texuniform="true" reflectance="0"/>
    <material name="hfield" texture="hfield" texrepeat="1 1" texuniform="true" reflectance="0"/>
    <material name="wall" texture="wall" texrepeat="1 1" texuniform="true" reflectance=".5"/>
    <material name="platform" texture="platform" texrepeat="1 1" texuniform="true" reflectance=".5"/>
    <material name="boundary" texture="boundary" texrepeat="1 1" texuniform="true" reflectance=".5"/>
    <material name="jump" texture="jump" texrepeat="1 1" texuniform="true" reflectance=".5"/>
    <material name="goal" rgba="1 0 0 1"/>
    <material name="food" rgba="0 0 1 1" emission="1"/>
    <material name="self" rgba=".7 .5 .3 1"/>
    <material name="self_default" rgba=".7 .5 .3 1"/>
    <material name="self_highlight" rgba="0 .5 .3 1"/>
    <material name="effector" rgba=".7 .4 .2 1"/>
    <material name="effector_default" rgba=".7 .4 .2 1"/>
    <material name="effector_highlight" rgba="0 .5 .3 1"/>
    <material name="decoration" rgba=".3 .5 .7 1"/>
    <material name="eye" rgba="0 .2 1 1"/>
    <material name="target" rgba=".6 .3 .3 1"/>
    <material name="target_default" rgba=".6 .3 .3 1"/>
    <material name="target_highlight" rgba=".6 .3 .3 .4"/>
    <material name="site" rgba=".5 .5 .5 .3"/>
    <material name="ball" texture="ball"/>
  </asset>
  <asset>
    <texture name="grid" type="2d" builtin="checker" rgb1="0.1 0.1 0.1" rgb2="0.1 0.1 0.1" width="300" height="300" mark="edge" markrgb="0.2 0.2 0.2"/>
    <texture name="hfield" type="2d" builtin="checker" rgb1="0.1 0.1 0.1" rgb2="0.1 0.1 0.1" width="300" height="300"/>
    <texture name="wall" type="2d" builtin="flat" rgb1="0.9 0.7 0" rgb2="0.9 0.7 0" width="300" height="300"/>
    <texture name="platform" type="2d" builtin="flat" rgb1="0.3 0 0.8" rgb2="0.3 0 0.8" width="300" height="300"/>
    <texture name="boundary" type="2d" builtin="flat" rgb1="0.3 0.3 0.3" rgb2="0.3 0.3 0.3" width="300" height="300"/>
    <texture name="jump" type="2d" builtin="flat" rgb1="0.3 0.3 0.3" rgb2="0.3 0.3 0.3" width="300" height="300"/>
    <texture name="skybox" type="skybox" builtin="flat" rgb1="0.8 1 1" rgb2="0.8 1 1" width="800" height="800"/>
    <texture name="ball" builtin="checker" mark="cross" width="151" height="151" rgb1="0.1 0.1 0.1" rgb2="0.9 0.9 0.9" markrgb="1 1 1"/>
  </asset>
  <visual>
    <headlight ambient=".4 .4 .4" diffuse=".8 .8 .8" specular="0.1 0.1 0.1"/>
    <map znear=".01"/>
    <quality shadowsize="2048"/>
  </visual>
</mujoco>
btaba commented 2 months ago

Hi @MasterXiong , I loaded the model and visualized with

mjpython -m mujoco.mjx.viewer --mjcf=tmp.xml

and it looks like the fixture falls forever. I removed the free joint at the root.

Then I ran your code, but it seems to contain some bugs (your env is producing incompatible shapes for the actions).

I ran this instead:

import jax
import mujoco
from mujoco import mjx
import numpy as np

# NOTE: `path` is not defined in this snippet; it must hold the location of
# the MJCF file to load.
m = mujoco.MjModel.from_xml_path(path)
mx = mjx.put_model(m)
dx = mjx.make_data(mx)

# Wrap the step function once, outside the loop, instead of re-wrapping it on
# every iteration.
jit_step = jax.jit(mjx.step)

# Drive every actuator with uniform random controls in [-1, 1) for 3000 steps.
for _ in range(3000):
  dx = dx.replace(ctrl=np.random.uniform(low=-1, high=1, size=(mx.nu,)))
  dx = jit_step(mx, dx)

and the results seem OK