Closed naijekux closed 5 months ago
I suspect the NaN values come from mjx.step(), which is called inside brax.mjx.pipeline.step():
def step(
    sys: System, state: State, act: jax.Array, unused_debug: bool = False
) -> State:
  data = state.replace(ctrl=act)
  data = mjx.step(sys, data)
  q, qd = data.qpos, data.qvel
  x = Transform(pos=data.xpos[1:], rot=data.xquat[1:])
  cvel = Motion(vel=data.cvel[1:, 3:], ang=data.cvel[1:, :3])
  offset = data.xpos[1:, :] - data.subtree_com[sys.body_rootid[1:]]
  offset = Transform.create(pos=offset)
  xd = offset.vmap().do(cvel)
  data = _reformat_contact(sys, data)
  return data.replace(q=q, qd=qd, x=x, xd=xd)
I suspect this because brax.mjx.pipeline.init() runs without problems when the environment is reset:
def init(
    sys: System, q: jax.Array, qd: jax.Array, unused_debug: bool = False
) -> State:
  data = mjx.make_data(sys)
  data = data.replace(qpos=q, qvel=qd)
  data = mjx.forward(sys, data)
  q, qd = data.qpos, data.qvel
  x = Transform(pos=data.xpos[1:], rot=data.xquat[1:])
  cvel = Motion(vel=data.cvel[1:, 3:], ang=data.cvel[1:, :3])
  offset = data.xpos[1:, :] - data.subtree_com[sys.body_rootid[1:]]
  offset = Transform.create(pos=offset)
  xd = offset.vmap().do(cvel)
  data = _reformat_contact(sys, data)
  return State(q=q, qd=qd, x=x, xd=xd, **data.__dict__)
Most of the code in step() is the same as in init(). The main difference is the line data = mjx.step(sys, data), where init() calls mjx.forward(sys, data) instead. Since the code in init() raises no errors, the NaN values must come from the call to mjx.step().
Hi @naijekux, can you do:
from jax import config
config.update("jax_debug_nans", True)
and report back? Just after a cursory scan, your timestep is quite high at 0.01s. Is the simulation stable when you load it in simulate?
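For readers unfamiliar with the flag, here is a minimal sketch of what jax_debug_nans does. The function bad_log is a hypothetical stand-in, not part of Brax or MJX. With the flag enabled, JAX checks computation outputs for NaNs and raises FloatingPointError instead of letting the NaN propagate silently.

```python
import jax
import jax.numpy as jnp

# Enable NaN checking on the outputs of JAX computations.
jax.config.update("jax_debug_nans", True)


@jax.jit
def bad_log(x):
    # Hypothetical example function: the log of a negative number is NaN.
    return jnp.log(x)


try:
    bad_log(jnp.array(-1.0))
    raised = False
except FloatingPointError as err:
    raised = True
    print("caught:", err)
finally:
    # The check slows execution, so turn it off again after debugging.
    jax.config.update("jax_debug_nans", False)
```

When the NaN comes from a jitted function, JAX re-runs the function without jit to try to locate the offending primitive, which is why the traceback from a real training run can be long.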
Thanks a lot @btaba.
The NaN-detection report is below. These tracebacks were generated by running the training function train_fn = functools.partial().
---------------------------------------------------------------------------
FloatingPointError Traceback (most recent call last)
[... skipping hidden 1 frame]
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/jax/_src/profiler.py:336, in annotate_function.<locals>.wrapper(*args, **kwargs)
[335] with TraceAnnotation(name, **decorator_kwargs):
--> [336] return func(*args, **kwargs)
[337] return wrapper
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py:1213, in ExecuteReplicated.__call__(self, *args)
[1212] for arrays in out_arrays:
-> [1213] dispatch.check_special(self.name, arrays)
[1214] return self.out_handler(out_arrays)
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/jax/_src/dispatch.py:314, in check_special(name, bufs)
[313] for buf in bufs:
--> [314] _check_special(name, buf.dtype, buf)
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/jax/_src/dispatch.py:319, in _check_special(name, dtype, buf)
[318] if config.debug_nans.value and np.any(np.isnan(np.asarray(buf))):
--> [319] raise FloatingPointError(f"invalid value (nan) encountered in {name}")
[320] if config.debug_infs.value and np.any(np.isinf(np.asarray(buf))):
FloatingPointError: invalid value (nan) encountered in jit(generate_eval_unroll)
During handling of the above exception, another exception occurred:
FloatingPointError Traceback (most recent call last)
[... skipping hidden 1 frame]
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/jax/_src/profiler.py:336, in annotate_function.<locals>.wrapper(*args, **kwargs)
[335] with TraceAnnotation(name, **decorator_kwargs):
--> [336] return func(*args, **kwargs)
[337] return wrapper
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py:1213, in ExecuteReplicated.__call__(self, *args)
[1212] for arrays in out_arrays:
-> [1213] dispatch.check_special(self.name, arrays)
[1214] return self.out_handler(out_arrays)
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/jax/_src/dispatch.py:314, in check_special(name, bufs)
[313] for buf in bufs:
--> [314] _check_special(name, buf.dtype, buf)
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/jax/_src/dispatch.py:319, in _check_special(name, dtype, buf)
[318] if config.debug_nans.value and np.any(np.isnan(np.asarray(buf))):
--> [319] raise FloatingPointError(f"invalid value (nan) encountered in {name}")
[320] if config.debug_infs.value and np.any(np.isinf(np.asarray(buf))):
FloatingPointError: invalid value (nan) encountered in jit(scan)
During handling of the above exception, another exception occurred:
JaxStackTraceBeforeTransformation Traceback (most recent call last)
File <frozen runpy>:198, in _run_module_as_main()
File <frozen runpy>:88, in _run_code()
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/ipykernel_launcher.py:18
[16] from ipykernel import kernelapp as app
---> [18] app.launch_new_instance()
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/traitlets/config/application.py:1075, in launch_instance()
[1074] app.initialize(argv)
-> [1075] app.start()
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/ipykernel/kernelapp.py:739, in start()
[738] try:
--> [739] self.io_loop.start()
[740] except KeyboardInterrupt:
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/tornado/platform/asyncio.py:195, in start()
[194] def start(self) -> None:
--> [195] self.asyncio_loop.run_forever()
File ~/anaconda3/envs/mjx/lib/python3.12/asyncio/base_events.py:638, in run_forever()
[637] while True:
--> [638] self._run_once()
[639] if self._stopping:
File ~/anaconda3/envs/mjx/lib/python3.12/asyncio/base_events.py:1971, in _run_once()
[1970] else:
-> [1971] handle._run()
[1972] handle = None
File ~/anaconda3/envs/mjx/lib/python3.12/asyncio/events.py:84, in _run()
[83] try:
---> [84] self._context.run(self._callback, *self._args)
[85] except (SystemExit, KeyboardInterrupt):
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/ipykernel/kernelbase.py:542, in dispatch_queue()
[541] try:
--> [542] await self.process_one()
[543] except Exception:
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/ipykernel/kernelbase.py:531, in process_one()
[530] return
--> [531] await dispatch(*args)
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/ipykernel/kernelbase.py:437, in dispatch_shell()
[436] if inspect.isawaitable(result):
--> [437] await result
[438] except Exception:
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/ipykernel/ipkernel.py:359, in execute_request()
[358] self._associate_new_top_level_threads_with(parent_header)
--> [359] await super().execute_request(stream, ident, parent)
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/ipykernel/kernelbase.py:775, in execute_request()
[774] if inspect.isawaitable(reply_content):
--> [775] reply_content = await reply_content
[777] # Flush output before sending the reply.
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/ipykernel/ipkernel.py:446, in do_execute()
[445] if accepts_params["cell_id"]:
--> [446] res = shell.run_cell(
[447] code,
[448] store_history=store_history,
[449] silent=silent,
[450] cell_id=cell_id,
[451] )
[452] else:
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/ipykernel/zmqshell.py:549, in run_cell()
[548] self._last_traceback = None
--> [549] return super().run_cell(*args, **kwargs)
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3051, in run_cell()
[3050] try:
-> [3051] result = self._run_cell(
[3052] raw_cell, store_history, silent, shell_futures, cell_id
[3053] )
[3054] finally:
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3106, in _run_cell()
[3105] try:
-> [3106] result = runner(coro)
[3107] except BaseException as e:
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/IPython/core/async_helpers.py:129, in _pseudo_sync_runner()
[128] try:
--> [129] coro.send(None)
[130] except StopIteration as exc:
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3311, in run_cell_async()
[3308] interactivity = "none" if silent else self.ast_node_interactivity
-> [3311] has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
[3312] interactivity=interactivity, compiler=compiler, result=result)
[3314] self.last_execution_succeeded = not has_raised
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3493, in run_ast_nodes()
[3492] asy = compare(code)
-> [3493] if await self.run_code(code, result, async_=asy):
[3494] return True
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3553, in run_code()
[3552] else:
-> [3553] exec(code_obj, self.user_global_ns, self.user_ns)
[3554] finally:
[3555] # Reset our crash handler in place
Cell In[30], line 32
[30] plt.show()
---> [32] make_inference_fn, params, _= train_fn(environment=env, progress_fn=progress)
[34] print(f'time to jit: {times[1] - times[0]}')
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/brax/training/agents/ppo/train.py:405, in train()
[404] if process_id == 0 and num_evals > 1:
--> [405] metrics = evaluator.run_evaluation(
[406] _unpmap(
[407] (training_state.normalizer_params, training_state.params.policy)),
[408] training_metrics={})
[409] logging.info(metrics)
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/brax/training/acting.py:125, in run_evaluation()
[124] t = time.time()
--> [125] eval_state = self._generate_eval_unroll(policy_params, unroll_key)
[126] eval_metrics = eval_state.info['eval_metrics']
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/brax/training/acting.py:107, in generate_eval_unroll()
[106] eval_first_state = eval_env.reset(reset_keys)
--> [107] return generate_unroll(
[108] eval_env,
[109] eval_first_state,
[110] eval_policy_fn(policy_params),
[111] key,
[112] unroll_length=episode_length // action_repeat)[0]
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/brax/training/acting.py:75, in generate_unroll()
[73] return (nstate, next_key), transition
---> [75] (final_state, _), data = jax.lax.scan(
[76] f, (env_state, key), (), length=unroll_length)
[77] return final_state, data
JaxStackTraceBeforeTransformation: FloatingPointError: invalid value (nan) encountered in jit(scan). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`.
It may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations.
If you see this error, consider opening a bug report at https://github.com/google/jax.
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
FloatingPointError Traceback (most recent call last)
Cell In[30], line 32
[28] plt.errorbar(
[29] x_data, y_data, yerr=ydataerr)
[30] plt.show()
---> [32] make_inference_fn, params, _= train_fn(environment=env, progress_fn=progress)
[34] print(f'time to jit: {times[1] - times[0]}')
[35] print(f'time to train: {times[-1] - times[1]}')
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/brax/training/agents/ppo/train.py:405, in train(environment, num_timesteps, episode_length, action_repeat, num_envs, max_devices_per_host, num_eval_envs, learning_rate, entropy_cost, discounting, seed, unroll_length, batch_size, num_minibatches, num_updates_per_batch, num_evals, num_resets_per_eval, normalize_observations, reward_scaling, clipping_epsilon, gae_lambda, deterministic_eval, network_factory, progress_fn, normalize_advantage, eval_env, policy_params_fn, randomization_fn)
[403] metrics = {}
[404] if process_id == 0 and num_evals > 1:
--> [405] metrics = evaluator.run_evaluation(
[406] _unpmap(
[407] (training_state.normalizer_params, training_state.params.policy)),
[408] training_metrics={})
[409] logging.info(metrics)
[410] progress_fn(0, metrics)
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/brax/training/acting.py:125, in Evaluator.run_evaluation(self, policy_params, training_metrics, aggregate_episodes)
[122] self._key, unroll_key = jax.random.split(self._key)
[124] t = time.time()
--> [125] eval_state = self._generate_eval_unroll(policy_params, unroll_key)
[126] eval_metrics = eval_state.info['eval_metrics']
[127] eval_metrics.active_episodes.block_until_ready()
[... skipping hidden 24 frame]
File ~/anaconda3/envs/mjx/lib/python3.12/site-packages/jax/_src/pjit.py:1372, in _pjit_call_impl_python(jaxpr, in_shardings, out_shardings, resource_env, donated_invars, name, keep_unused, inline, *args)
[1355] # If control reaches this line, we got a NaN on the output of `compiled`
[1356] # but not `fun.call_wrapped` on the same arguments. Let's tell the user.
[1357] msg = (f"{str(e)}. Because "
[1358] "jax_config.debug_nans.value and/or config.jax_debug_infs is set, the "
[1359] "de-optimized function (i.e., the function as if the `jit` "
(...)
[1370] "If you see this error, consider opening a bug report at "
[1371] "https://github.com/google/jax.")
-> [1372] raise FloatingPointError(msg)
FloatingPointError: invalid value (nan) encountered in jit(scan). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`.
It may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations.
If you see this error, consider opening a bug report at https://github.com/google/jax.
The simulation runs well in simulate and can be controlled by setting the actuator values.
Thanks in advance.
Ok, from the traceback, it looks like the non-jitted function does not produce a NaN, but the jitted function does. This often occurs if the simulation is sensitive to numerical precision differences between the jitted and non-jitted functions. Can you try [1] lowering the timestep (0.01 is high) and [2] running with jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGH), to see if either of those makes a difference? If not, please send a complete example we can use to repro the NaN (folder with XML + meshes).
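As a rough illustration of the precision suggestion, the sketch below shows two ways to raise JAX's default matmul precision. It assumes the installed JAX version accepts the string aliases "tensorfloat32" (documented as an alias for Precision.HIGH) and "float32" (alias for Precision.HIGHEST). The arrays a and b are made-up inputs, not taken from the simulation.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Option 1: set the default globally, before running any jitted computation.
# Assumption: "tensorfloat32" is the string alias for jax.lax.Precision.HIGH.
jax.config.update("jax_default_matmul_precision", "tensorfloat32")

# Made-up inputs, only for illustration.
a = jnp.arange(12, dtype=jnp.float32).reshape(3, 4) / 7.0
b = jnp.arange(20, dtype=jnp.float32).reshape(4, 5) / 3.0

global_result = jnp.matmul(a, b)

# Option 2: raise the precision only inside a block. "float32" requests
# full float32 accuracy, which is stricter than the HIGH setting.
with jax.default_matmul_precision("float32"):
    scoped_result = jax.jit(jnp.matmul)(a, b)

# Compare against a float64 NumPy reference to see the remaining error.
reference = np.asarray(a, dtype=np.float64) @ np.asarray(b, dtype=np.float64)
max_err = float(np.abs(np.asarray(scoped_result) - reference).max())
print("max abs error vs float64 reference:", max_err)
```

On CPU the setting usually makes no visible difference. It matters mainly on accelerators, where the default may use reduced-precision matrix multiplies.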
Hi @btaba, thanks a lot for your help these days. I actually found the solution a few days ago. The NaNs are caused by the initial velocity imposed on the cube that is to be pushed. NaN values no longer appear if no initial velocity is imposed on the cube, although I don't know the reason for this. But training in my case is far slower than in the mjx colab humanoid example: more than 40 min for just 5e6 total steps. The folder with the code and XML + meshes for my environment can be found here under my repo on GitHub. Thanks.
See https://mujoco.readthedocs.io/en/stable/mjx.html#performance-tuning if you're seeing slow training
Hi @btaba, it's quite helpful, thanks for your help these days.
Hi,
I'm looking for some help with MJX.
The function pipeline_step() ('mjx' backend), called inside step() of the gymnasium env framework, doesn't work well: it returns new data (pipeline_state) consisting of NaN values. I found the issue by printing some variables inside step() to check whether every variable has a valid value, and the output is like:
The only variable that contains NaN values at the beginning is pipeline_state.q[:7], which is returned from pipeline_state = self.pipeline_step(state.pipeline_state, action), while the other variables were assigned before calling pipeline_step(). That means the NaN error happens in pipeline_step().
I suppose that the two input parameters, pipeline_state and action, have nothing to do with the problem mentioned above, and that the MJCF model loaded from an XML file causes it.
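The manual check described above (printing variables to see which ones hold NaNs) can be automated. The following is a minimal sketch, assuming the state is a JAX pytree of arrays. The helper find_nan_fields and the dictionary fake_state are hypothetical and only stand in for a real pipeline state.

```python
import jax
import jax.numpy as jnp


def find_nan_fields(tree):
    """Return the paths of all floating-point leaves in `tree` that contain a NaN."""
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    bad = []
    for path, leaf in leaves_with_paths:
        arr = jnp.asarray(leaf)
        # Only floating-point arrays can hold NaNs.
        if jnp.issubdtype(arr.dtype, jnp.floating) and bool(jnp.isnan(arr).any()):
            bad.append(jax.tree_util.keystr(path))
    return bad


# Hypothetical stand-in for a pipeline state; a real state has many more fields.
fake_state = {
    "q": jnp.array([jnp.nan, 0.0, 1.0]),
    "qd": jnp.zeros(3),
    "ctrl": jnp.ones(2),
}

nan_fields = find_nan_fields(fake_state)
print(nan_fields)
```

The helper converts results to Python booleans, so it is meant to be called on concrete values outside jit, for example on the state returned by a single un-jitted step.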
The same issue also happens when using the Brax 'generalized' backend instead of 'mjx'. Details can be found here, also posted on Brax's GitHub.
Here is the model created from an XML file:
Could someone please help me with this issue? Thanks