google-deepmind / mujoco

Multi-Joint dynamics with Contact. A general purpose physics simulator.
https://mujoco.org
Apache License 2.0

[MJX] Function pipeline_step() returns data raising NaN values error #1484

Closed naijekux closed 5 months ago

naijekux commented 6 months ago

Hi,

I'm looking for some help with MJX.

The function pipeline_step() (with the 'mjx' backend), called inside step() of my environment built on the gymnasium env framework, does not work correctly: it returns new data (pipeline_state) containing NaN values.

I found the issue by printing several variables inside step() to check that each of them has a valid value.

import jax.numpy as jp
from jax import debug
from brax.envs.base import State

def step(self, state: State, action: jp.ndarray) -> State:
    """Runs one timestep of the environment's dynamics."""
    # A plain print() runs only while JAX traces step() (e.g. under jit or scan), not on every step.
    print('step begins')
    data0 = state.pipeline_state

    # With this XML, xpos[-3], xpos[-2] and xpos[-1] are the attachment, object and goal bodies.
    vec_1 = data0.xpos[-2,:] - data0.xpos[-3,:]
    vec_2 = data0.xpos[-2,:] - data0.xpos[-1,:]
    reward_near = -jp.linalg.norm(vec_1)
    reward_dist = -jp.linalg.norm(vec_2)
    reward_ctrl = -jp.square(action).sum()
    reward = reward_dist + 0.1 * reward_ctrl + 0.5 * reward_near

    data = self.pipeline_step(state.pipeline_state, action)

    obs = self._get_obs(data)

    state.metrics.update(
        reward_dist = reward_dist,
        reward_ctrl = reward_ctrl,
        reward_near = reward_near,
    )
    print('step ends')
    # jax.debug.print prints the runtime values, even inside jit/scan.
    debug.print("vec_1:{x}",x=vec_1)
    debug.print("reward_near:{x}",x=reward_near)
    debug.print("q:{x}",x=data.qpos[:7])
    debug.print("goal_pos:{x}",x=data0.xpos[-1,:])
    return state.replace(pipeline_state=data, obs=obs, reward=reward)

and the output looks like this:

vec_1:[-0.362  0.35  -1.001]
vec_1:[-0.362  0.35  -1.001]
vec_1:[-0.362  0.35  -1.001]
...
goal_pos:[ 0.65  -0.15  -0.324]
goal_pos:[ 0.65  -0.15  -0.324]
goal_pos:[ 0.65  -0.15  -0.324]
...
reward_near:-1.1205109357833862
reward_near:-1.1205109357833862
reward_near:-1.1205109357833862
...
q:[nan nan nan nan nan nan nan]
q:[nan nan nan nan nan nan nan]
q:[nan nan nan nan nan nan nan]
q:[nan nan nan nan nan nan nan]
q:[nan nan nan nan nan nan nan]
q:[nan nan nan nan nan nan nan]

The only variable outputting NaN values from the very beginning is pipeline_state.q[:7] (printed as data.qpos[:7] above), which comes from data = self.pipeline_step(state.pipeline_state, action); all the other variables were computed before pipeline_step() is called. This means the NaNs are produced inside pipeline_step().
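To see at a glance which parts of the returned state go bad, instead of printing a few of them by hand, you can check every array field for NaNs. A minimal NumPy sketch, assuming you first copy the fields of interest into a plain dict on concrete (non-traced) arrays; `nan_fields` is a hypothetical helper, not part of MJX or Brax:

```python
import numpy as np

def nan_fields(fields):
    """Return the names of fields whose arrays contain at least one NaN.

    `fields` is assumed to map a field name (e.g. "qpos") to an array-like value.
    """
    bad = []
    for name, value in fields.items():
        arr = np.asarray(value, dtype=np.float64)
        if np.isnan(arr).any():
            bad.append(name)
    return bad

# Made-up values mirroring the output above: only qpos contains NaNs.
example = {
    "qpos": [np.nan] * 7,
    "xpos": [[0.65, -0.15, -0.324]],
}
print(nan_fields(example))  # ['qpos']
```

With MJX data you could build the dict as `{name: getattr(data, name) for name in ("qpos", "qvel", "qacc")}` outside of jit; under jit the arrays are tracers and cannot be converted with `np.asarray`.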

I suspect that the two inputs, pipeline_state and action, have nothing to do with the problem, and that the MJCF model loaded from the XML file is what causes it.

The same issue also happens when using Brax's 'generalized' backend instead of 'mjx'. Details can be found here, in the issue I also posted on Brax's GitHub.

Here is the model, defined in an XML file:

<mujoco model="panda nohand">
  <compiler angle="radian" meshdir="assets" autolimits="true"/>

  <option timestep="0.01" gravity="0 0 0"/>

  <default>
    <default class="panda">
      <material specular="0.5" shininess="0.25"/>
      <joint armature="0.1" damping="1" axis="0 0 1" range="-2.8973 2.8973"/>
      <general dyntype="none" biastype="affine" ctrlrange="-2.8973 2.8973" forcerange="-87 87" ctrllimited="true"/>

      <default class="visual">
        <geom type="mesh" contype="0" conaffinity="0" group="2"/>
      </default>
      <default class="collision">
        <geom type="mesh" contype="0" conaffinity="1" group="3"/>
      </default>
      <site size="0.001" rgba="0.5 0.5 0.5 0.3" group="4"/>
    </default>
  </default>

  <asset>
    <material class="panda" name="white" rgba="1 1 1 1"/>
    <material class="panda" name="off_white" rgba="0.901961 0.921569 0.929412 1"/>
    <material class="panda" name="dark_grey" rgba="0.25 0.25 0.25 1"/>
    <material class="panda" name="green" rgba="0 1 0 1"/>
    <material class="panda" name="light_blue" rgba="0.039216 0.541176 0.780392 1"/>
    <!--
    <material class="panda" name="red" rgba="1 0 0 1"/>
    -->

    <!-- Collision meshes -->
    <mesh name="link0_c" file="link0.stl"/>
    <mesh name="link1_c" file="link1.stl"/>
    <mesh name="link2_c" file="link2.stl"/>
    <mesh name="link3_c" file="link3.stl"/>
    <mesh name="link4_c" file="link4.stl"/>
    <mesh name="link5_c0" file="link5_collision_0.obj"/>
    <mesh name="link5_c1" file="link5_collision_1.obj"/>
    <mesh name="link5_c2" file="link5_collision_2.obj"/>
    <mesh name="link6_c" file="link6.stl"/>
    <mesh name="link7_c" file="link7.stl"/>

    <!-- Visual meshes -->
    <mesh file="link0_0.obj"/>
    <mesh file="link0_1.obj"/>
    <mesh file="link0_2.obj"/>
    <mesh file="link0_3.obj"/>
    <mesh file="link0_4.obj"/>
    <mesh file="link0_5.obj"/>
    <mesh file="link0_7.obj"/>
    <mesh file="link0_8.obj"/>
    <mesh file="link0_9.obj"/>
    <mesh file="link0_10.obj"/>
    <mesh file="link0_11.obj"/>
    <mesh file="link1.obj"/>
    <mesh file="link2.obj"/>
    <mesh file="link3_0.obj"/>
    <mesh file="link3_1.obj"/>
    <mesh file="link3_2.obj"/>
    <mesh file="link3_3.obj"/>
    <mesh file="link4_0.obj"/>
    <mesh file="link4_1.obj"/>
    <mesh file="link4_2.obj"/>
    <mesh file="link4_3.obj"/>
    <mesh file="link5_0.obj"/>
    <mesh file="link5_1.obj"/>
    <mesh file="link5_2.obj"/>
    <mesh file="link6_0.obj"/>
    <mesh file="link6_1.obj"/>
    <mesh file="link6_2.obj"/>
    <mesh file="link6_3.obj"/>
    <mesh file="link6_4.obj"/>
    <mesh file="link6_5.obj"/>
    <mesh file="link6_6.obj"/>
    <mesh file="link6_7.obj"/>
    <mesh file="link6_8.obj"/>
    <mesh file="link6_9.obj"/>
    <mesh file="link6_10.obj"/>
    <mesh file="link6_11.obj"/>
    <mesh file="link6_12.obj"/>
    <mesh file="link6_13.obj"/>
    <mesh file="link6_14.obj"/>
    <mesh file="link6_15.obj"/>
    <mesh file="link6_16.obj"/>
    <mesh file="link7_0.obj"/>
    <mesh file="link7_1.obj"/>
    <mesh file="link7_2.obj"/>
    <mesh file="link7_3.obj"/>
    <mesh file="link7_4.obj"/>
    <mesh file="link7_5.obj"/>
    <mesh file="link7_6.obj"/>
    <mesh file="link7_7.obj"/>
  </asset>

  <worldbody>
    <light name="top" pos="0 0 2" mode="trackcom"/>

    <geom name="table" type="plane" pos="0 0.5 -0.325" size="1 1 0.1" contype="1" conaffinity="1"/>

    <body name="link0" childclass="panda" pos="1.1 -0.5 -0.2" quat="0 0 0 1">
      <inertial mass="0.629769" pos="-0.041018 -0.00014 0.049974"
        fullinertia="0.00315 0.00388 0.004285 8.2904e-7 0.00015 8.2299e-6"/>
      <geom mesh="link0_0" material="off_white" class="visual"/>
      <geom mesh="link0_1" material="dark_grey" class="visual"/>
      <geom mesh="link0_2" material="off_white" class="visual"/>
      <geom mesh="link0_3" material="dark_grey" class="visual"/>
      <geom mesh="link0_4" material="off_white" class="visual"/>
      <geom mesh="link0_5" material="dark_grey" class="visual"/>
      <geom mesh="link0_7" material="white" class="visual"/>
      <geom mesh="link0_8" material="white" class="visual"/>
      <geom mesh="link0_9" material="dark_grey" class="visual"/>
      <geom mesh="link0_10" material="off_white" class="visual"/>
      <geom mesh="link0_11" material="white" class="visual"/>
      <geom mesh="link0_c" class="collision"/>

      <body name="link1" pos="0 0 0.333">
        <inertial mass="4.970684" pos="0.003875 0.002081 -0.04762"
          fullinertia="0.70337 0.70661 0.0091170 -0.00013900 0.0067720 0.019169"/>
        <joint name="joint1"/>
        <geom material="white" mesh="link1" class="visual"/>
        <geom mesh="link1_c" class="collision"/>

        <body name="link2" quat="1 -1 0 0">
          <inertial mass="0.646926" pos="-0.003141 -0.02872 0.003495"
            fullinertia="0.0079620 2.8110e-2 2.5995e-2 -3.925e-3 1.0254e-2 7.04e-4"/>
          <joint name="joint2" range="-1.7628 1.7628"/>
          <geom material="white" mesh="link2" class="visual"/>
          <geom mesh="link2_c" class="collision"/>

          <body name="link3" pos="0 -0.316 0" quat="1 1 0 0">
            <inertial mass="3.228604" pos="2.7518e-2 3.9252e-2 -6.6502e-2"
              fullinertia="3.7242e-2 3.6155e-2 1.083e-2 -4.761e-3 -1.1396e-2 -1.2805e-2"/>
            <joint name="joint3"/>
            <geom mesh="link3_0" material="white" class="visual"/>
            <geom mesh="link3_1" material="white" class="visual"/>
            <geom mesh="link3_2" material="white" class="visual"/>
            <geom mesh="link3_3" material="dark_grey" class="visual"/>
            <geom mesh="link3_c" class="collision"/>

            <body name="link4" pos="0.0825 0 0" quat="1 1 0 0">
              <inertial mass="3.587895" pos="-5.317e-2 1.04419e-1 2.7454e-2"
                fullinertia="2.5853e-2 1.9552e-2 2.8323e-2 7.796e-3 -1.332e-3 8.641e-3"/>
              <joint name="joint4" range="-3.0718 -0.0698"/>
              <geom mesh="link4_0" material="white" class="visual"/>
              <geom mesh="link4_1" material="white" class="visual"/>
              <geom mesh="link4_2" material="dark_grey" class="visual"/>
              <geom mesh="link4_3" material="white" class="visual"/>
              <geom mesh="link4_c" class="collision"/>

              <body name="link5" pos="-0.0825 0.384 0" quat="1 -1 0 0">
                <inertial mass="1.225946" pos="-1.1953e-2 4.1065e-2 -3.8437e-2"
                  fullinertia="3.5549e-2 2.9474e-2 8.627e-3 -2.117e-3 -4.037e-3 2.29e-4"/>
                <joint name="joint5"/>
                <geom mesh="link5_0" material="dark_grey" class="visual"/>
                <geom mesh="link5_1" material="white" class="visual"/>
                <geom mesh="link5_2" material="white" class="visual"/>
                <geom mesh="link5_c0" class="collision"/>
                <geom mesh="link5_c1" class="collision"/>
                <geom mesh="link5_c2" class="collision"/>

                <body name="link6" quat="1 1 0 0">
                  <inertial mass="1.666555" pos="6.0149e-2 -1.4117e-2 -1.0517e-2"
                    fullinertia="1.964e-3 4.354e-3 5.433e-3 1.09e-4 -1.158e-3 3.41e-4"/>
                  <joint name="joint6" range="-0.0175 3.7525"/>
                  <geom mesh="link6_0" material="off_white" class="visual"/>
                  <geom mesh="link6_1" material="white" class="visual"/>
                  <geom mesh="link6_2" material="dark_grey" class="visual"/>
                  <geom mesh="link6_3" material="white" class="visual"/>
                  <geom mesh="link6_4" material="white" class="visual"/>
                  <geom mesh="link6_5" material="white" class="visual"/>
                  <geom mesh="link6_6" material="white" class="visual"/>
                  <geom mesh="link6_7" material="light_blue" class="visual"/>
                  <geom mesh="link6_8" material="light_blue" class="visual"/>
                  <geom mesh="link6_9" material="dark_grey" class="visual"/>
                  <geom mesh="link6_10" material="dark_grey" class="visual"/>
                  <geom mesh="link6_11" material="white" class="visual"/>
                  <geom mesh="link6_12" material="green" class="visual"/>
                  <geom mesh="link6_13" material="white" class="visual"/>
                  <geom mesh="link6_14" material="dark_grey" class="visual"/>
                  <geom mesh="link6_15" material="dark_grey" class="visual"/>
                  <geom mesh="link6_16" material="white" class="visual"/>
                  <geom mesh="link6_c" class="collision"/>

                  <body name="link7" pos="0.088 0 0" quat="1 1 0 0">
                    <inertial mass="7.35522e-01" pos="1.0517e-2 -4.252e-3 6.1597e-2"
                      fullinertia="1.2516e-2 1.0027e-2 4.815e-3 -4.28e-4 -1.196e-3 -7.41e-4"/>
                    <joint name="joint7"/>
                    <geom mesh="link7_0" material="white" class="visual"/>
                    <geom mesh="link7_1" material="dark_grey" class="visual"/>
                    <geom mesh="link7_2" material="dark_grey" class="visual"/>
                    <geom mesh="link7_3" material="dark_grey" class="visual"/>
                    <geom mesh="link7_4" material="dark_grey" class="visual"/>
                    <geom mesh="link7_5" material="dark_grey" class="visual"/>
                    <geom mesh="link7_6" material="dark_grey" class="visual"/>
                    <geom mesh="link7_7" material="white" class="visual"/>
                    <geom mesh="link7_c" class="collision"/>

                    <body name="attachment" pos="0 0 0.107" quat="0.3826834 0 0 0.9238795">
                      <site name="attachment_site"/>
                    </body>
                  </body>
                </body>
              </body>
            </body>
          </body>
        </body>
      </body>
    </body>

    <body name="object" pos="0.65 -0.15 -0.275" >
      <geom rgba="1 1 1 0" type="sphere" size="0.05 0.05 0.05" density="1000" conaffinity="0"/>
      <geom rgba="1 1 1 1" type="box" size="0.05 0.05 0.05" density="1000" contype="1" conaffinity="0"/>
      <joint name="obj_slidey" type="slide" pos="0 0 0" axis="0 1 0" range="-10.3213 10.3" damping="0.5"/>
      <joint name="obj_slidex" type="slide" pos="0 0 0" axis="1 0 0" range="-10.3213 10.3" damping="0.5"/>
    </body>
    <body name="goal" pos="0.65 -0.15 -0.3240">
      <geom rgba="1 0 0 1" type="box" size="0.08 0.08 0.001" density='0.00001' contype="0" conaffinity="0"/>
      <joint name="goal_slidey" type="slide" pos="0 0 0" axis="0 1 0" range="-10.3213 10.3" damping="0.5"/>
      <joint name="goal_slidex" type="slide" pos="0 0 0" axis="1 0 0" range="-10.3213 10.3" damping="0.5"/>
    </body>
  </worldbody>

  <actuator>
    <general class="panda" name="actuator1" joint="joint1"/>
    <general class="panda" name="actuator2" joint="joint2" ctrlrange="-1.7628 1.7628"/>
    <general class="panda" name="actuator3" joint="joint3" />
    <general class="panda" name="actuator4" joint="joint4" ctrlrange="-3.0718 -0.0698"/>
    <general class="panda" name="actuator5" joint="joint5" forcerange="-12 12"/>
    <general class="panda" name="actuator6" joint="joint6" forcerange="-12 12" ctrlrange="-0.0175 3.7525"/>
    <general class="panda" name="actuator7" joint="joint7" forcerange="-12 12"/>
  </actuator>

  <keyframe>
    <key name="home" qpos="0 0 0 -1.57079 0 1.57079 -0.7853 0 0 0 0"/>
  </keyframe>
</mujoco>

Could someone please help me with this issue? Thanks

naijekux commented 6 months ago

I suspect the NaN values are produced by mjx.step(), which is called inside brax.mjx.pipeline.step():

def step(
    sys: System, state: State, act: jax.Array, unused_debug: bool = False
) -> State:

  data = state.replace(ctrl=act)
  data = mjx.step(sys, data)

  q, qd = data.qpos, data.qvel
  x = Transform(pos=data.xpos[1:], rot=data.xquat[1:])
  cvel = Motion(vel=data.cvel[1:, 3:], ang=data.cvel[1:, :3])
  offset = data.xpos[1:, :] - data.subtree_com[sys.body_rootid[1:]]
  offset = Transform.create(pos=offset)
  xd = offset.vmap().do(cvel)

  data = _reformat_contact(sys, data)
  return data.replace(q=q, qd=qd, x=x, xd=xd)

I think so because brax.mjx.pipeline.init(), which runs when the environment is reset, causes no problem:

def init(
    sys: System, q: jax.Array, qd: jax.Array, unused_debug: bool = False
) -> State:
  data = mjx.make_data(sys)
  data = data.replace(qpos=q, qvel=qd)
  data = mjx.forward(sys, data)

  q, qd = data.qpos, data.qvel
  x = Transform(pos=data.xpos[1:], rot=data.xquat[1:])
  cvel = Motion(vel=data.cvel[1:, 3:], ang=data.cvel[1:, :3])
  offset = data.xpos[1:, :] - data.subtree_com[sys.body_rootid[1:]]
  offset = Transform.create(pos=offset)
  xd = offset.vmap().do(cvel)

  data = _reformat_contact(sys, data)
  return State(q=q, qd=qd, x=x, xd=xd, **data.__dict__)

Most of the code in step() is the same as in init(), apart from the line data = mjx.step(sys, data). Since the code in init() generates no errors, the NaN values must come from calling mjx.step().
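The same process of elimination can be automated by stepping a single environment in a plain Python loop and stopping at the first step whose state is non-finite. A minimal sketch: `first_nonfinite_step` and the toy dynamics are hypothetical; with MJX, `step_fn` could be something like `lambda d: mjx.step(model, d)` together with `to_array=lambda d: d.qpos`, which is an assumption about how you would structure the loop.

```python
import numpy as np

def first_nonfinite_step(step_fn, state, num_steps, to_array=lambda s: s):
    """Step `state` forward and return the index of the first non-finite step.

    Returns None if all `num_steps` states stay finite. `to_array` selects
    the array to check (for MJX data this might be `lambda d: d.qpos`).
    """
    for i in range(num_steps):
        state = step_fn(state)
        if not np.all(np.isfinite(np.asarray(to_array(state)))):
            return i
    return None

def exploding_step(s):
    # Toy dynamics: multiply by 1e100 each step, overflowing to inf on step index 3.
    with np.errstate(over="ignore"):
        return s * 1e100

print(first_nonfinite_step(exploding_step, np.array([1.0]), 10))      # 3
print(first_nonfinite_step(lambda s: s * 0.5, np.array([1.0]), 10))   # None
```

Running the loop without jit is slow, but it gives the exact step at which things break, which you can then inspect in detail.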

btaba commented 6 months ago

Hi @naijekux , can you do:

from jax import config
config.update("jax_debug_nans", True)

and report back? From a cursory scan, your timestep is quite high at 0.01 s. Is the simulation stable when you load it in simulate?
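Why a large timestep can end in NaNs can be seen on a toy problem. The sketch below integrates a single undamped spring, x'' = -ω²x, with semi-implicit (symplectic) Euler. It is only an illustration, not MJX's integrator or this model. For this scheme the update is stable only when dt·ω < 2, so at dt = 0.01 a spring with ω = 100 stays bounded while one with ω = 300 grows each step until it overflows to inf and then NaN.

```python
import math

def simulate_spring(omega, dt, num_steps, x0=1.0, v0=0.0):
    """Integrate x'' = -omega**2 * x with semi-implicit Euler; return final (x, v)."""
    x, v = x0, v0
    for _ in range(num_steps):
        v = v - dt * omega**2 * x   # update the velocity first...
        x = x + dt * v              # ...then the position, using the new velocity
    return x, v

dt = 0.01
for omega in (100.0, 300.0):  # dt * omega = 1.0 (stable) and 3.0 (unstable)
    x, v = simulate_spring(omega, dt, num_steps=1000)
    print(f"dt*omega={dt * omega:.1f}: finite={math.isfinite(x) and math.isfinite(v)}")
# dt*omega=1.0: finite=True
# dt*omega=3.0: finite=False
```

Lowering dt (or reducing stiffness) moves the system back into the stable region, which is the intuition behind the suggestion to reduce the timestep.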

naijekux commented 6 months ago

Thanks a lot @btaba. The NaN-detection report is below; these tracebacks were generated by running the training function train_fn = functools.partial().

---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
    [... skipping hidden 1 frame]

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/jax/_src/profiler.py:336, in annotate_function.<locals>.wrapper(*args, **kwargs)
    [335] with TraceAnnotation(name, **decorator_kwargs):
--> [336]   return func(*args, **kwargs)
    [337] return wrapper

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py:1213, in ExecuteReplicated.__call__(self, *args)
   [1212] for arrays in out_arrays:
-> [1213]   dispatch.check_special(self.name, arrays)
   [1214] return self.out_handler(out_arrays)

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/jax/_src/dispatch.py:314, in check_special(name, bufs)
    [313] for buf in bufs:
--> [314]   _check_special(name, buf.dtype, buf)

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/jax/_src/dispatch.py:319, in _check_special(name, dtype, buf)
    [318] if config.debug_nans.value and np.any(np.isnan(np.asarray(buf))):
--> [319]   raise FloatingPointError(f"invalid value (nan) encountered in {name}")
    [320] if config.debug_infs.value and np.any(np.isinf(np.asarray(buf))):

FloatingPointError: invalid value (nan) encountered in jit(generate_eval_unroll)

During handling of the above exception, another exception occurred:

FloatingPointError                        Traceback (most recent call last)
    [... skipping hidden 1 frame]

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/jax/_src/profiler.py:336, in annotate_function.<locals>.wrapper(*args, **kwargs)
    [335] with TraceAnnotation(name, **decorator_kwargs):
--> [336]   return func(*args, **kwargs)
    [337] return wrapper

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py:1213, in ExecuteReplicated.__call__(self, *args)
   [1212] for arrays in out_arrays:
-> [1213]   dispatch.check_special(self.name, arrays)
   [1214] return self.out_handler(out_arrays)

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/jax/_src/dispatch.py:314, in check_special(name, bufs)
    [313] for buf in bufs:
--> [314]   _check_special(name, buf.dtype, buf)

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/jax/_src/dispatch.py:319, in _check_special(name, dtype, buf)
    [318] if config.debug_nans.value and np.any(np.isnan(np.asarray(buf))):
--> [319]   raise FloatingPointError(f"invalid value (nan) encountered in {name}")
    [320] if config.debug_infs.value and np.any(np.isinf(np.asarray(buf))):

FloatingPointError: invalid value (nan) encountered in jit(scan)

During handling of the above exception, another exception occurred:

JaxStackTraceBeforeTransformation         Traceback (most recent call last)
File <frozen runpy>:198, in _run_module_as_main()

File <frozen runpy>:88, in _run_code()

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/ipykernel_launcher.py:18
     [16] from ipykernel import kernelapp as app
---> [18] app.launch_new_instance()

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/traitlets/config/application.py:1075, in launch_instance()
   [1074] app.initialize(argv)
-> [1075] app.start()

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/ipykernel/kernelapp.py:739, in start()
    [738] try:
--> [739]     self.io_loop.start()
    [740] except KeyboardInterrupt:

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/tornado/platform/asyncio.py:195, in start()
    [194] def start(self) -> None:
--> [195]     self.asyncio_loop.run_forever()

File ~/anaconda3/envs/mjx/lib/python3.12/asyncio/base_events.py:638, in run_forever()
    [637] while True:
--> [638]    self._run_once()
    [639]     if self._stopping:

File ~/anaconda3/envs/mjx/lib/python3.12/asyncio/base_events.py:1971, in _run_once()
   [1970]     else:
-> [1971]         handle._run()
   [1972] handle = None

File ~/anaconda3/envs/mjx/lib/python3.12/asyncio/events.py:84, in _run()
     [83] try:
---> [84]     self._context.run(self._callback, *self._args)
     [85] except (SystemExit, KeyboardInterrupt):

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/ipykernel/kernelbase.py:542, in dispatch_queue()
    [541] try:
--> [542]     await self.process_one()
    [543] except Exception:

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/ipykernel/kernelbase.py:531, in process_one()
    [530]         return
--> [531] await dispatch(*args)

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/ipykernel/kernelbase.py:437, in dispatch_shell()
    [436]     if inspect.isawaitable(result):
--> [437]         await result
    [438] except Exception:

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/ipykernel/ipkernel.py:359, in execute_request()
    [358] self._associate_new_top_level_threads_with(parent_header)
--> [359] await super().execute_request(stream, ident, parent)

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/ipykernel/kernelbase.py:775, in execute_request()
    [774] if inspect.isawaitable(reply_content):
--> [775]     reply_content = await reply_content
    [777] # Flush output before sending the reply.

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/ipykernel/ipkernel.py:446, in do_execute()
    [445] if accepts_params["cell_id"]:
--> [446]     res = shell.run_cell(
    [447]         code,
    [448]         store_history=store_history,
    [449]         silent=silent,
    [450]         cell_id=cell_id,
    [451]     )
    [452] else:

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/ipykernel/zmqshell.py:549, in run_cell()
    [548] self._last_traceback = None
--> [549] return super().run_cell(*args, **kwargs)

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3051, in run_cell()
   [3050] try:
-> [3051]     result = self._run_cell(
   [3052]         raw_cell, store_history, silent, shell_futures, cell_id
   [3053]     )
   [3054] finally:

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3106, in _run_cell()
   [3105] try:
-> [3106]     result = runner(coro)
   [3107] except BaseException as e:

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/IPython/core/async_helpers.py:129, in _pseudo_sync_runner()
    [128] try:
--> [129]     coro.send(None)
    [130] except StopIteration as exc:

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3311, in run_cell_async()
   [3308] interactivity = "none" if silent else self.ast_node_interactivity
-> [3311] has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
   [3312]        interactivity=interactivity, compiler=compiler, result=result)
   [3314] self.last_execution_succeeded = not has_raised

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3493, in run_ast_nodes()
   [3492]     asy = compare(code)
-> [3493] if await self.run_code(code, result, async_=asy):
   [3494]     return True

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3553, in run_code()
   [3552]     else:
-> [3553]         exec(code_obj, self.user_global_ns, self.user_ns)
   [3554] finally:
   [3555]     # Reset our crash handler in place

Cell In[30], line 32
     [30]   plt.show()
---> [32] make_inference_fn, params, _= train_fn(environment=env, progress_fn=progress)
     [34] print(f'time to jit: {times[1] - times[0]}')

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/brax/training/agents/ppo/train.py:405, in train()
    [404] if process_id == 0 and num_evals > 1:
--> [405]   metrics = evaluator.run_evaluation(
    [406]       _unpmap(
    [407]           (training_state.normalizer_params, training_state.params.policy)),
    [408]       training_metrics={})
    [409]   logging.info(metrics)

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/brax/training/acting.py:125, in run_evaluation()
    [124] t = time.time()
--> [125] eval_state = self._generate_eval_unroll(policy_params, unroll_key)
    [126] eval_metrics = eval_state.info['eval_metrics']

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/brax/training/acting.py:107, in generate_eval_unroll()
    [106] eval_first_state = eval_env.reset(reset_keys)
--> [107] return generate_unroll(
    [108]     eval_env,
    [109]     eval_first_state,
    [110]     eval_policy_fn(policy_params),
    [111]     key,
     [112]     unroll_length=episode_length // action_repeat)[0]

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/brax/training/acting.py:75, in generate_unroll()
     [73]   return (nstate, next_key), transition
---> [75] (final_state, _), data = jax.lax.scan(
     [76]     f, (env_state, key), (), length=unroll_length)
     [77] return final_state, data

JaxStackTraceBeforeTransformation: FloatingPointError: invalid value (nan) encountered in jit(scan). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. 

It may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations. 

If you see this error, consider opening a bug report at https://github.com/google/jax.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

FloatingPointError                        Traceback (most recent call last)
Cell In[30], line 32
     [28]   plt.errorbar(
     [29]       x_data, y_data, yerr=ydataerr)
     [30]   plt.show()
---> [32] make_inference_fn, params, _= train_fn(environment=env, progress_fn=progress)
     [34] print(f'time to jit: {times[1] - times[0]}')
     [35] print(f'time to train: {times[-1] - times[1]}')

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/brax/training/agents/ppo/train.py:405, in train(environment, num_timesteps, episode_length, action_repeat, num_envs, max_devices_per_host, num_eval_envs, learning_rate, entropy_cost, discounting, seed, unroll_length, batch_size, num_minibatches, num_updates_per_batch, num_evals, num_resets_per_eval, normalize_observations, reward_scaling, clipping_epsilon, gae_lambda, deterministic_eval, network_factory, progress_fn, normalize_advantage, eval_env, policy_params_fn, randomization_fn)
    [403] metrics = {}
    [404] if process_id == 0 and num_evals > 1:
--> [405]   metrics = evaluator.run_evaluation(
    [406]       _unpmap(
    [407]           (training_state.normalizer_params, training_state.params.policy)),
    [408]       training_metrics={})
    [409]   logging.info(metrics)
    [410]   progress_fn(0, metrics)

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/brax/training/acting.py:125, in Evaluator.run_evaluation(self, policy_params, training_metrics, aggregate_episodes)
    [122] self._key, unroll_key = jax.random.split(self._key)
    [124] t = time.time()
--> [125] eval_state = self._generate_eval_unroll(policy_params, unroll_key)
    [126] eval_metrics = eval_state.info['eval_metrics']
    [127] eval_metrics.active_episodes.block_until_ready()

    [... skipping hidden 24 frame]

File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/jax/_src/pjit.py:1372, in _pjit_call_impl_python(jaxpr, in_shardings, out_shardings, resource_env, donated_invars, name, keep_unused, inline, *args)
   [1355] # If control reaches this line, we got a NaN on the output of `compiled`
   [1356] # but not `fun.call_wrapped` on the same arguments. Let's tell the user.
   [1357] msg = (f"{str(e)}. Because "
   [1358]        "jax_config.debug_nans.value and/or config.jax_debug_infs is set, the "
   [1359]        "de-optimized function (i.e., the function as if the `jit` "
   (...)
   [1370]        "If you see this error, consider opening a bug report at "
   [1371]        "https://github.com/google/jax.")
-> [1372] raise FloatingPointError(msg)

FloatingPointError: invalid value (nan) encountered in jit(scan). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. 

It may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations. 

If you see this error, consider opening a bug report at https://github.com/google/jax.

The simulation runs well in simulate and can be controlled by setting the actuator values.

[Screenshot from 2024-03-12 14-55-06]

Thanks in advance.

btaba commented 6 months ago

Ok, from the traceback, it looks like the non-jitted function does not produce a NaN, but the jitted function does. This often occurs when the simulation is sensitive to numerical precision differences between the jitted and non-jitted functions. Can you try [1] lowering the timestep (0.01 is high) and [2] running with jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGH), to see if either makes a difference? If not, please send a complete example we can use to reproduce the NaN (a folder with the XML + meshes).
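To see how small precision differences can matter: float32 addition is not associative, so compiled and uncompiled code that evaluate the same expression in a different order (or with different precision) can produce different results. A minimal NumPy illustration of the effect, not taken from MJX:

```python
import numpy as np

a = np.float32(1e8)
b = np.float32(1.0)

# Mathematically both expressions equal b, but the order of operations
# decides whether b survives rounding in float32 (spacing near 1e8 is 8).
reordered = (a + b) - a      # 1e8 + 1 rounds back to 1e8, so this is 0.0
direct = b + (a - a)         # a - a is exactly 0, so this is 1.0
print(reordered, direct)     # 0.0 1.0

# In float64 the spacing near 1e8 is much finer, so the reordered form gives 1.0 too.
print((np.float64(1e8) + 1.0) - np.float64(1e8))  # 1.0
```

In a long simulation such differences can accumulate, and a marginally stable setup can tip into divergence in one version but not the other.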

naijekux commented 6 months ago

Hi @btaba, thanks a lot for your help these past days. I had actually found the solution a few days ago: the NaNs were caused by the initial velocity imposed on the cube that is to be pushed. NaN values no longer appear if no initial velocity is imposed on the cube, although I don't know the reason for this.

However, training in my case is much slower than in the MJX Colab humanoid example: more than 40 min for just 5e6 total steps. The folder with the code and XML + meshes for my environment can be found here, in my repo on GitHub. Thanks.
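For reference, here is a minimal sketch of building an initial state in which the cube starts at rest. The DoF layout follows the XML above: seven arm hinges, then obj_slidey/obj_slidex, then goal_slidey/goal_slidex, one DoF each, 11 in total. The helper name `initial_qpos_qvel` and the optional arm-velocity noise are hypothetical and only for illustration.

```python
import numpy as np

# "home" keyframe from the XML: 7 arm joints, then cube (y, x), then goal (y, x).
HOME_QPOS = np.array([0, 0, 0, -1.57079, 0, 1.57079, -0.7853, 0, 0, 0, 0])
ARM = slice(0, 7)    # joint1 .. joint7
CUBE = slice(7, 9)   # obj_slidey, obj_slidex
GOAL = slice(9, 11)  # goal_slidey, goal_slidex

def initial_qpos_qvel(rng, arm_vel_noise=0.0):
    """Return (qpos, qvel) at the home pose with the cube and goal at rest."""
    qpos = HOME_QPOS.astype(np.float64)
    qvel = np.zeros_like(qpos)
    # Optional small random arm velocities (hypothetical; 0 disables them).
    qvel[ARM] = rng.uniform(-arm_vel_noise, arm_vel_noise, size=7)
    # qvel[CUBE] and qvel[GOAL] stay zero: an initial cube velocity
    # was what triggered the NaNs in this thread.
    return qpos, qvel

qpos, qvel = initial_qpos_qvel(np.random.default_rng(0), arm_vel_noise=0.01)
print(qvel[CUBE], qvel[GOAL])  # [0. 0.] [0. 0.]
```

In a Brax PipelineEnv you would presumably convert these to jax.numpy arrays and pass them to `self.pipeline_init(qpos, qvel)` in reset(); check the exact call against your Brax version.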

btaba commented 5 months ago

See https://mujoco.readthedocs.io/en/stable/mjx.html#performance-tuning if you're seeing slow training
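As a rough yardstick when following that guide, the figure reported above (5e6 environment steps in more than 40 minutes) works out to at most about 2,000 steps per second. A trivial, hypothetical helper for tracking this across runs:

```python
def steps_per_second(total_steps, wall_clock_minutes):
    """Average environment steps per second over a training run."""
    return total_steps / (wall_clock_minutes * 60.0)

# 5e6 steps / 2400 s; an upper bound here, since the run took more than 40 min.
print(round(steps_per_second(5e6, 40)))  # 2083
```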

naijekux commented 5 months ago

Hi @btaba, that's very helpful. Thanks for your help these past days.