google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Episode reward is NaN during training #467

Closed. naijekux closed this issue 3 months ago.

naijekux commented 4 months ago

Hi,

I used the Pusher example environment, modified only to import the MJCF of a Panda robot arm without a hand. The training code is unchanged from the Pusher example in the Colab.

However, the episode reward is NaN when plotting the training progress, as shown in the screenshot below. The first plot appears without any curve, and no further plots are produced after it. (The screenshot shows the output after 9 minutes.) The GPU stays busy the whole time, so the process is not idle.

[Screenshot from 2024-03-06 12-36-40]

Judging from print statements at the start and end of reset() and step(), execution appears to freeze after the first step() call, so the problem seems to arise during training. I am sure my own code contains no operations that commonly produce NaN, such as division by zero.
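One caveat when reading those prints: if step() runs under jax.jit (as the Colab's jit_step does, and as Brax training typically does), a plain Python print executes only while JAX traces the function, not on every call. Seeing "step begins/ends" only once therefore does not by itself show that execution froze; jax.debug.print, by contrast, fires each time the compiled function runs. The sketch below illustrates this on a hypothetical toy function `double` standing in for step(); it is not code from the thread.

```python
import jax
import jax.numpy as jnp

# Counts how many times the Python body of `double` is executed.
trace_count = 0


@jax.jit
def double(x):
    """Hypothetical toy function standing in for env.step."""
    global trace_count
    trace_count += 1                  # Python side effect: runs only while tracing
    print("traced")                   # printed once, at trace time
    jax.debug.print("x = {x}", x=x)   # printed on every call of the compiled function
    return 2.0 * x


for i in range(3):
    double(jnp.full((3,), float(i)))  # same shape and dtype each call, so no retrace

print("trace_count =", trace_count)
```

Three calls produce three jax.debug.print lines but only one "traced" line, because the compiled function is reused.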

The XML file for the Panda robot is below:

<mujoco model="panda nohand">
  <compiler angle="radian" meshdir="assets" autolimits="true"/>

  <option timestep="0.01" gravity="0 0 0" iterations="20"/>

  <default>
    <default class="panda">
      <material specular="0.5" shininess="0.25"/>
      <joint armature="0.1" damping="1" axis="0 0 1" range="-2.8973 2.8973"/>
      <general dyntype="none" biastype="affine" ctrlrange="-2.8973 2.8973" forcerange="-87 87" ctrllimited="true"/>

      <default class="visual">
        <geom type="mesh" contype="0" conaffinity="0" group="2"/>
      </default>
      <default class="collision">
        <geom type="mesh" contype="0" conaffinity="1" group="3"/>
      </default>
      <site size="0.001" rgba="0.5 0.5 0.5 0.3" group="4"/>
    </default>
  </default>

  <asset>
    <material class="panda" name="white" rgba="1 1 1 1"/>
    <material class="panda" name="off_white" rgba="0.901961 0.921569 0.929412 1"/>
    <material class="panda" name="dark_grey" rgba="0.25 0.25 0.25 1"/>
    <material class="panda" name="green" rgba="0 1 0 1"/>
    <material class="panda" name="light_blue" rgba="0.039216 0.541176 0.780392 1"/>
    <!--
    <material class="panda" name="red" rgba="1 0 0 1"/>
    -->

    <!-- Collision meshes -->
    <mesh name="link0_c" file="link0.stl"/>
    <mesh name="link1_c" file="link1.stl"/>
    <mesh name="link2_c" file="link2.stl"/>
    <mesh name="link3_c" file="link3.stl"/>
    <mesh name="link4_c" file="link4.stl"/>
    <mesh name="link5_c0" file="link5_collision_0.obj"/>
    <mesh name="link5_c1" file="link5_collision_1.obj"/>
    <mesh name="link5_c2" file="link5_collision_2.obj"/>
    <mesh name="link6_c" file="link6.stl"/>
    <mesh name="link7_c" file="link7.stl"/>

    <!-- Visual meshes -->
    <mesh file="link0_0.obj"/>
    <mesh file="link0_1.obj"/>
    <mesh file="link0_2.obj"/>
    <mesh file="link0_3.obj"/>
    <mesh file="link0_4.obj"/>
    <mesh file="link0_5.obj"/>
    <mesh file="link0_7.obj"/>
    <mesh file="link0_8.obj"/>
    <mesh file="link0_9.obj"/>
    <mesh file="link0_10.obj"/>
    <mesh file="link0_11.obj"/>
    <mesh file="link1.obj"/>
    <mesh file="link2.obj"/>
    <mesh file="link3_0.obj"/>
    <mesh file="link3_1.obj"/>
    <mesh file="link3_2.obj"/>
    <mesh file="link3_3.obj"/>
    <mesh file="link4_0.obj"/>
    <mesh file="link4_1.obj"/>
    <mesh file="link4_2.obj"/>
    <mesh file="link4_3.obj"/>
    <mesh file="link5_0.obj"/>
    <mesh file="link5_1.obj"/>
    <mesh file="link5_2.obj"/>
    <mesh file="link6_0.obj"/>
    <mesh file="link6_1.obj"/>
    <mesh file="link6_2.obj"/>
    <mesh file="link6_3.obj"/>
    <mesh file="link6_4.obj"/>
    <mesh file="link6_5.obj"/>
    <mesh file="link6_6.obj"/>
    <mesh file="link6_7.obj"/>
    <mesh file="link6_8.obj"/>
    <mesh file="link6_9.obj"/>
    <mesh file="link6_10.obj"/>
    <mesh file="link6_11.obj"/>
    <mesh file="link6_12.obj"/>
    <mesh file="link6_13.obj"/>
    <mesh file="link6_14.obj"/>
    <mesh file="link6_15.obj"/>
    <mesh file="link6_16.obj"/>
    <mesh file="link7_0.obj"/>
    <mesh file="link7_1.obj"/>
    <mesh file="link7_2.obj"/>
    <mesh file="link7_3.obj"/>
    <mesh file="link7_4.obj"/>
    <mesh file="link7_5.obj"/>
    <mesh file="link7_6.obj"/>
    <mesh file="link7_7.obj"/>
  </asset>

  <worldbody>
    <light name="top" pos="0 0 2" mode="trackcom"/>

    <geom name="table" type="plane" pos="0 0.5 -0.325" size="1 1 0.1" contype="1" conaffinity="1"/>

    <body name="link1" childclass="panda" pos="1.1 -0.5 0.133" quat="0 0 0 1">
      <inertial mass="4.970684" pos="0.003875 0.002081 -0.04762"
        fullinertia="0.70337 0.70661 0.0091170 -0.00013900 0.0067720 0.019169"/>
      <joint name="joint1"/>
      <geom material="white" mesh="link1" class="visual"/>
      <geom mesh="link1_c" class="collision"/>

      <body name="link2" quat="1 -1 0 0">
        <inertial mass="0.646926" pos="-0.003141 -0.02872 0.003495"
          fullinertia="0.0079620 2.8110e-2 2.5995e-2 -3.925e-3 1.0254e-2 7.04e-4"/>
        <joint name="joint2" range="-1.7628 1.7628"/>
        <geom material="white" mesh="link2" class="visual"/>
        <geom mesh="link2_c" class="collision"/>

        <body name="link3" pos="0 -0.316 0" quat="1 1 0 0">
          <inertial mass="3.228604" pos="2.7518e-2 3.9252e-2 -6.6502e-2"
            fullinertia="3.7242e-2 3.6155e-2 1.083e-2 -4.761e-3 -1.1396e-2 -1.2805e-2"/>
          <joint name="joint3"/>
          <geom mesh="link3_0" material="white" class="visual"/>
          <geom mesh="link3_1" material="white" class="visual"/>
          <geom mesh="link3_2" material="white" class="visual"/>
          <geom mesh="link3_3" material="dark_grey" class="visual"/>
          <geom mesh="link3_c" class="collision"/>

          <body name="link4" pos="0.0825 0 0" quat="1 1 0 0">
            <inertial mass="3.587895" pos="-5.317e-2 1.04419e-1 2.7454e-2"
              fullinertia="2.5853e-2 1.9552e-2 2.8323e-2 7.796e-3 -1.332e-3 8.641e-3"/>
            <joint name="joint4" range="-3.0718 -0.0698"/>
            <geom mesh="link4_0" material="white" class="visual"/>
            <geom mesh="link4_1" material="white" class="visual"/>
            <geom mesh="link4_2" material="dark_grey" class="visual"/>
            <geom mesh="link4_3" material="white" class="visual"/>
            <geom mesh="link4_c" class="collision"/>

            <body name="link5" pos="-0.0825 0.384 0" quat="1 -1 0 0">
              <inertial mass="1.225946" pos="-1.1953e-2 4.1065e-2 -3.8437e-2"
                fullinertia="3.5549e-2 2.9474e-2 8.627e-3 -2.117e-3 -4.037e-3 2.29e-4"/>
              <joint name="joint5"/>
              <geom mesh="link5_0" material="dark_grey" class="visual"/>
              <geom mesh="link5_1" material="white" class="visual"/>
              <geom mesh="link5_2" material="white" class="visual"/>
              <geom mesh="link5_c0" class="collision"/>
              <geom mesh="link5_c1" class="collision"/>
              <geom mesh="link5_c2" class="collision"/>

              <body name="link6" quat="1 1 0 0">
                <inertial mass="1.666555" pos="6.0149e-2 -1.4117e-2 -1.0517e-2"
                  fullinertia="1.964e-3 4.354e-3 5.433e-3 1.09e-4 -1.158e-3 3.41e-4"/>
                <joint name="joint6" range="-0.0175 3.7525"/>
                <geom mesh="link6_0" material="off_white" class="visual"/>
                <geom mesh="link6_1" material="white" class="visual"/>
                <geom mesh="link6_2" material="dark_grey" class="visual"/>
                <geom mesh="link6_3" material="white" class="visual"/>
                <geom mesh="link6_4" material="white" class="visual"/>
                <geom mesh="link6_5" material="white" class="visual"/>
                <geom mesh="link6_6" material="white" class="visual"/>
                <geom mesh="link6_7" material="light_blue" class="visual"/>
                <geom mesh="link6_8" material="light_blue" class="visual"/>
                <geom mesh="link6_9" material="dark_grey" class="visual"/>
                <geom mesh="link6_10" material="dark_grey" class="visual"/>
                <geom mesh="link6_11" material="white" class="visual"/>
                <geom mesh="link6_12" material="green" class="visual"/>
                <geom mesh="link6_13" material="white" class="visual"/>
                <geom mesh="link6_14" material="dark_grey" class="visual"/>
                <geom mesh="link6_15" material="dark_grey" class="visual"/>
                <geom mesh="link6_16" material="white" class="visual"/>
                <geom mesh="link6_c" class="collision"/>

                <body name="link7" pos="0.088 0 0" quat="1 1 0 0">
                  <inertial mass="7.35522e-01" pos="1.0517e-2 -4.252e-3 6.1597e-2"
                    fullinertia="1.2516e-2 1.0027e-2 4.815e-3 -4.28e-4 -1.196e-3 -7.41e-4"/>
                  <joint name="joint7"/>
                  <geom mesh="link7_0" material="white" class="visual"/>
                  <geom mesh="link7_1" material="dark_grey" class="visual"/>
                  <geom mesh="link7_2" material="dark_grey" class="visual"/>
                  <geom mesh="link7_3" material="dark_grey" class="visual"/>
                  <geom mesh="link7_4" material="dark_grey" class="visual"/>
                  <geom mesh="link7_5" material="dark_grey" class="visual"/>
                  <geom mesh="link7_6" material="dark_grey" class="visual"/>
                  <geom mesh="link7_7" material="white" class="visual"/>
                  <geom mesh="link7_c" class="collision"/>

                  <!-- <body name="attachment" pos="0 0 0.107" quat="0.3826834 0 0 0.9238795">
                    <site name="attachment_site"/>
                  </body> -->
                </body>
              </body>
            </body>
          </body>
        </body>
      </body>
    </body>

    <body name="object" pos="0.65 -0.15 -0.275" >
      <geom rgba="1 1 1 0" type="sphere" size="0.05 0.05 0.05" density="1000" conaffinity="0"/>
      <geom rgba="1 1 1 1" type="box" size="0.05 0.05 0.05" density="1000" contype="1" conaffinity="0"/>
      <joint name="obj_slidey" type="slide" pos="0 0 0" axis="0 1 0" range="-10.3213 10.3" damping="0.5"/>
      <joint name="obj_slidex" type="slide" pos="0 0 0" axis="1 0 0" range="-10.3213 10.3" damping="0.5"/>
    </body>    
    <body name="goal" pos="0.65 -0.15 -0.3240">
      <geom rgba="1 0 0 1" type="box" size="0.08 0.08 0.001" density='0.00001' contype="0" conaffinity="0"/>
      <joint name="goal_slidey" type="slide" pos="0 0 0" axis="0 1 0" range="-10.3213 10.3" damping="0.5"/>
      <joint name="goal_slidex" type="slide" pos="0 0 0" axis="1 0 0" range="-10.3213 10.3" damping="0.5"/>
    </body>
  </worldbody>

  <actuator>
    <general class="panda" name="actuator1" joint="joint1"/>
    <general class="panda" name="actuator2" joint="joint2" ctrlrange="-1.7628 1.7628"/>
    <general class="panda" name="actuator3" joint="joint3" />
    <general class="panda" name="actuator4" joint="joint4" ctrlrange="-3.0718 -0.0698"/>
    <general class="panda" name="actuator5" joint="joint5" forcerange="-12 12"/>
    <general class="panda" name="actuator6" joint="joint6" forcerange="-12 12" ctrlrange="-0.0175 3.7525"/>
    <general class="panda" name="actuator7" joint="joint7" forcerange="-12 12"/>
  </actuator>

  <keyframe>
    <key name="home" qpos="0 0 0 -1.57079 0 1.57079 -0.7853 0 0 0 0"/>
  </keyframe>
</mujoco>
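Because this model was hand-edited, one cheap sanity check (a suggestion, not something tried in the thread) is that every keyframe's qpos has exactly nq entries. The sketch below uses only the Python standard library; `check_keyframes` is a hypothetical helper, and as a simplification it treats joints without a `type` attribute as hinges and ignores joint types set through `<default>` classes.

```python
import xml.etree.ElementTree as ET

# Number of qpos entries contributed by each MJCF joint type.
QPOS_SIZE = {"free": 7, "ball": 4, "hinge": 1, "slide": 1}


def check_keyframes(mjcf: str) -> dict:
    """Return {key name: (len(qpos), expected nq)} for each <key> in the MJCF.

    Simplification: a joint without an explicit type is counted as a hinge
    (the MJCF default); types assigned via <default> classes are ignored.
    """
    root = ET.fromstring(mjcf)
    worldbody = root.find("worldbody")
    nq = 0
    for joint in worldbody.iter("joint"):
        nq += QPOS_SIZE[joint.get("type", "hinge")]
    for _ in worldbody.iter("freejoint"):
        nq += 7  # <freejoint/> is shorthand for a free joint
    result = {}
    for key in root.iter("key"):
        qpos = key.get("qpos", "").split()
        result[key.get("name", "")] = (len(qpos), nq)
    return result


# Toy model: one hinge plus one slide joint, so nq = 2.
demo = """
<mujoco>
  <worldbody>
    <body name="arm"><joint name="j1"/><joint name="j2" type="slide"/></body>
  </worldbody>
  <keyframe><key name="home" qpos="0 0.5"/></keyframe>
</mujoco>
"""
print(check_keyframes(demo))  # {'home': (2, 2)}
```

By a manual count, the file above has seven hinge joints (joint1 to joint7) and four slide joints, so nq = 11, matching the 11 values in the "home" key; a qpos length mismatch is therefore not an obvious culprit here.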

Could someone please help me with this issue? Thanks

naijekux commented 4 months ago

I also printed some variables inside step():

import jax
from jax import debug
from jax import numpy as jp
from brax import base
from brax import math
from brax.envs.base import State


def step(self, state: State, action: jax.Array) -> State:
    """Runs one timestep of the environment's dynamics."""
    # When step() is jitted, plain print runs only while JAX traces it.
    print('step begins')
    assert state.pipeline_state is not None
    # World-frame positions of each link's center of mass.
    x_i = state.pipeline_state.x.vmap().do(
        base.Transform.create(pos=self.sys.link.inertia.transform.pos)
    )
    vec_1 = x_i.pos[self._object_idx] - x_i.pos[self._hand_idx]
    vec_2 = x_i.pos[self._object_idx] - x_i.pos[self._goal_idx]
    reward_near = -math.safe_norm(vec_1)
    reward_dist = -math.safe_norm(vec_2)
    reward_ctrl = -jp.square(action).sum()
    reward = reward_dist + 0.1 * reward_ctrl + 0.5 * reward_near

    # Advance the physics by one control step.
    pipeline_state = self.pipeline_step(state.pipeline_state, action)

    obs = self._get_obs(pipeline_state)

    state.metrics.update(
        reward_dist=reward_dist,
        reward_ctrl=reward_ctrl,
        reward_near=reward_near,
    )
    print('step ends')
    # debug.print runs on every call of the compiled function.
    debug.print("vec_1:{x}", x=vec_1)
    debug.print("reward_near:{x}", x=reward_near)
    debug.print("q:{x}", x=pipeline_state.q[:7])
    debug.print("goal_pos:{x}", x=x_i.pos[self._goal_idx])
    return state.replace(pipeline_state=pipeline_state, obs=obs, reward=reward)

The output is:

reset starts
reset ends
reset starts
reset ends
step begins
step ends
goal_pos:[ 0.65  -0.15  -0.324]
vec_1:[-0.351  0.354 -1.046]
vec_1:[-0.351  0.354 -1.046]
vec_1:[-0.351  0.354 -1.046]
vec_1:[-0.351  0.354 -1.046]
vec_1:[-0.351  0.354 -1.046]
vec_1:[-0.351  0.354 -1.046]
vec_1:[-0.351  0.354 -1.046]
reward_near:-1.1593074798583984
reward_near:-1.1593074798583984
reward_near:-1.1593074798583984
reward_near:-1.1593074798583984
...
...
...
goal_pos:[ 0.65  -0.15  -0.324]
goal_pos:[ 0.65  -0.15  -0.324]
goal_pos:[ 0.65  -0.15  -0.324]
goal_pos:[ 0.65  -0.15  -0.324]
q:[nan nan nan nan nan nan nan]
q:[nan nan nan nan nan nan nan]
q:[nan nan nan nan nan nan nan]
q:[nan nan nan nan nan nan nan]
q:[nan nan nan nan nan nan nan]
q:[nan nan nan nan nan nan nan]
q:[nan nan nan nan nan nan nan]
q:[nan nan nan nan nan nan nan]
q:[nan nan nan nan nan nan nan]
q:[nan nan nan nan nan nan nan]

The only variable that is NaN from the very first printout is pipeline_state.q[:7], which is returned by pipeline_state = self.pipeline_step(state.pipeline_state, action); all the other printed variables are computed before pipeline_step() is called. This means the NaNs originate inside pipeline_step().
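To narrow down which operation inside pipeline_step() first produces a NaN, JAX's `jax_debug_nans` flag may help: with it enabled, JAX checks the outputs of jitted computations for NaN and, on finding one, re-runs the computation de-optimized and raises FloatingPointError at the offending operation. This was not tried in the thread; the sketch below shows the mechanism on a hypothetical toy function `unsafe_log`. In the real setup one would enable the flag before calling jit_reset/jit_step, at a considerable cost in speed.

```python
import jax
import jax.numpy as jnp


@jax.jit
def unsafe_log(x):
    """Hypothetical toy function: the log of a negative number is NaN."""
    return jnp.log(x)


def produces_nan_error(x) -> bool:
    """Return True if evaluating unsafe_log(x) trips JAX's NaN checker."""
    jax.config.update("jax_debug_nans", True)  # enable NaN checking globally
    try:
        unsafe_log(x).block_until_ready()
        return False
    except FloatingPointError as err:
        print("NaN detected:", type(err).__name__)
        return True
    finally:
        jax.config.update("jax_debug_nans", False)  # restore the default


print(produces_nan_error(jnp.array(2.0)))   # finite result
print(produces_nan_error(jnp.array(-1.0)))  # log(-1) is NaN
```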

Next, I tested whether the NaNs are caused by the action input, by stepping the environment with a constant action (an array of ones scaled by -0.1):

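import jax
from jax import numpy as jp
import mediapy as media

# env is the modified Pusher environment described above.
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

# initialize the state
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

# grab a trajectory with a constant action of -0.1 on every actuator
for i in range(10):
  ctrl = -0.1 * jp.ones(env.sys.nu)
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)

media.show_video(env.render(rollout), fps=1.0 / env.dt)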
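With a rollout like this, it can help to locate the first state whose generalized positions contain NaN and then inspect the state just before it. The sketch below uses NumPy and works on a plain list of q arrays, for example `[np.asarray(s.q) for s in rollout]`, assuming each pipeline state exposes `q` as the prints in step() suggest; `first_nan_step` is a hypothetical helper, not part of Brax.

```python
import numpy as np


def first_nan_step(qs):
    """Return the index of the first array in `qs` that contains NaN, or None."""
    for i, q in enumerate(qs):
        if np.isnan(np.asarray(q)).any():
            return i
    return None


# Toy trajectory: two finite states, then a NaN state (illustrative values only).
qs = [np.zeros(11), np.full(11, 0.1), np.full(11, np.nan)]
print(first_nan_step(qs))  # 2
```

Since index 0 is the reset state, a result of 1 would mean the very first call to pipeline_step() already diverges, which is consistent with the q values printed below.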

The printed output still contains NaNs:

reset
step begins
step ends
vec_1:[-0.362  0.35  -1.001]
reward_near:-1.1205109357833862
q:[nan nan nan nan nan nan nan]
goal_pos:[ 0.65  -0.15  -0.324]
vec_1:[nan nan nan]
reward_near:nan
q:[nan nan nan nan nan nan nan]
goal_pos:[nan nan nan]
vec_1:[nan nan nan]
reward_near:nan
q:[nan nan nan nan nan nan nan]
goal_pos:[nan nan nan]
vec_1:[nan nan nan]
reward_near:nan
q:[nan nan nan nan nan nan nan]
goal_pos:[nan nan nan]
vec_1:[nan nan nan]
reward_near:nan
q:[nan nan nan nan nan nan nan]
goal_pos:[nan nan nan]
vec_1:[nan nan nan]
reward_near:nan
...
vec_1:[nan nan nan]
reward_near:nan
q:[nan nan nan nan nan nan nan]
goal_pos:[nan nan nan]

So the action input is ruled out as the cause of the NaNs inside pipeline_step().

btaba commented 3 months ago

Closing this issue since a duplicate one was opened https://github.com/google-deepmind/mujoco/issues/1484