google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

NaNs at Inference #501

Open willthibault opened 1 month ago

willthibault commented 1 month ago

Hello,

I have been encountering an issue where training runs error-free and the policy learns well, but NaN control values are generated at inference when I collect a trajectory to make a video of the task.

I am currently using the following lines to improve the precision and debug NaNs:

import os
os.environ["JAX_TRACEBACK_FILTERING"] = "off"  # set before importing jax so the setting takes effect

import jax

jax.config.update('jax_debug_nans', True)
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGH)
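For context on the traceback below: with jax_debug_nans enabled, when a jitted computation returns a NaN, JAX re-runs it without jit to try to pinpoint the primitive responsible, and raises FloatingPointError either way. That re-run is the "de-optimized function" mentioned in the error message. The minimal sketch below uses a toy function (unsafe_sqrt and nan_is_caught_under_jit are hypothetical names, not part of the environment) to show the usual case, where the de-optimized re-run also hits the NaN:

```python
import jax
import jax.numpy as jnp


@jax.jit
def unsafe_sqrt(x):
    # sqrt of a negative float32 input produces NaN
    return jnp.sqrt(x)


def nan_is_caught_under_jit() -> bool:
    """Return True if jax_debug_nans turns a NaN from a jitted call into an error."""
    jax.config.update("jax_debug_nans", True)
    try:
        unsafe_sqrt(-1.0)
    except FloatingPointError:
        return True
    finally:
        # Restore the default (off) so later code runs without the check.
        jax.config.update("jax_debug_nans", False)
    return False
```

In the traceback below, the second case occurs instead: the jitted step produced a NaN but the de-optimized re-run did not.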

The error generated at inference from MuJoCo is:

WARNING: Nan, Inf or huge value in CTRL at ACTUATOR 0. The simulation is unstable. Time = 0.6000.

The error from the inference is the following:

/usr/local/lib/python3.9/dist-packages/jax/_src/interpreters/xla.py:155: RuntimeWarning: overflow encountered in cast
  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
Traceback (most recent call last):
  File "/usr/local/lib/python3.9/dist-packages/jax/_src/pjit.py", line 1568, in _pjit_call_impl_python
    return compiled.unsafe_call(*args), compiled
  File "/usr/local/lib/python3.9/dist-packages/jax/_src/profiler.py", line 335, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/jax/_src/interpreters/pxla.py", line 1258, in __call__
    dispatch.check_special(self.name, arrays)
  File "/usr/local/lib/python3.9/dist-packages/jax/_src/dispatch.py", line 315, in check_special
    _check_special(name, buf.dtype, buf)
  File "/usr/local/lib/python3.9/dist-packages/jax/_src/dispatch.py", line 320, in _check_special
    raise FloatingPointError(f"invalid value (nan) encountered in {name}")
FloatingPointError: invalid value (nan) encountered in jit(step)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.9/dist-packages/jax/_src/pjit.py", line 1568, in _pjit_call_impl_python
    return compiled.unsafe_call(*args), compiled
  File "/usr/local/lib/python3.9/dist-packages/jax/_src/profiler.py", line 335, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/jax/_src/interpreters/pxla.py", line 1258, in __call__
    dispatch.check_special(self.name, arrays)
  File "/usr/local/lib/python3.9/dist-packages/jax/_src/dispatch.py", line 315, in check_special
    _check_special(name, buf.dtype, buf)
  File "/usr/local/lib/python3.9/dist-packages/jax/_src/dispatch.py", line 320, in _check_special
    raise FloatingPointError(f"invalid value (nan) encountered in {name}")
FloatingPointError: invalid value (nan) encountered in jit(scan)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/workspace/train.py", line 154, in <module>
    state = jit_step(state, ctrl)
  File "/home/workspace/envs/env.py", line 710, in step
    pipeline_state = self.pipeline_step(state.pipeline_state, motor_targets)
  File "/usr/local/lib/python3.9/dist-packages/brax/envs/base.py", line 183, in pipeline_step
    return jax.lax.scan(f, pipeline_state, (), self._n_frames)[0]
jax._src.source_info_util.JaxStackTraceBeforeTransformation: FloatingPointError: invalid value (nan) encountered in jit(scan). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. 

It may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations. 

If you see this error, consider opening a bug report at https://github.com/google/jax.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/workspace/train.py", line 154, in <module>
    state = jit_step(state, ctrl)
FloatingPointError: invalid value (nan) encountered in jit(scan). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. 

It may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations. 

If you see this error, consider opening a bug report at https://github.com/google/jax.

I'm not sure how training could work well and then generate NaNs at inference, since a NaN during training would have thrown an error. My model does include a fair number of contacts and two equality constraints that form a closed loop, but the model appears stable both in MuJoCo and during training.

I do have a workaround that fixes the issue, which is increasing to 64-bit precision:

jax.config.update('jax_enable_x64', True)

My main concern here is that training time and the GPU memory required both increase drastically. Training for 1 million steps went from 1 min 42 s to 3 min 42 s (on an RTX 4090), and the GPU memory allocated went from ~20 GB to ~46 GB. Excluding some contacts let me reduce this to 2 min 56 s and get back under the 24 GB limit, so I can keep using this GPU.
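Because the NaNs only show up at inference, one thing that may be worth testing (an untested assumption, not a confirmed fix) is keeping training in float32 and enabling x64 only in a separate evaluation script, casting the restored policy parameters to float64 there. cast_floats_to_float64 is a hypothetical helper:

```python
import jax
import jax.numpy as jnp

# Assumption: this lives in a separate evaluation script, so training stays float32.
# Enable x64 before any arrays (environment, params) are created.
jax.config.update("jax_enable_x64", True)


def cast_floats_to_float64(tree):
    """Hypothetical helper: cast the floating-point leaves of a pytree to float64."""

    def cast(leaf):
        leaf = jnp.asarray(leaf)
        if jnp.issubdtype(leaf.dtype, jnp.floating):
            return leaf.astype(jnp.float64)
        return leaf  # leave integer/bool leaves (e.g. step counters) unchanged

    return jax.tree_util.tree_map(cast, tree)


# e.g. after loading the trained policy: params = cast_floats_to_float64(params)
```

Casting only floating-point leaves avoids silently turning integer leaves into floats.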

My pipeline mirrors the Barkour training and inference pipeline very closely.

Some model details that may help (also very similar to the Barkour model):
- training dt = 0.02
- model.opt.timestep = 0.005
- integrator = Euler (I did also try RK4 and it didn't help)
- eulerdamp = disable
- iterations = 1
- ls_iterations = 5
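The settings above could be written as an MJCF <option> element, and the ratio of the two step sizes gives the number of physics substeps (n_frames) that pipeline_step scans over per control step (see the jax.lax.scan call in the traceback). This is a hypothetical reconstruction, not the author's model file; physics_substeps is an illustrative helper:

```python
# Hypothetical reconstruction of the settings listed above; values come from the
# list, attribute names from the MJCF <option>/<flag> schema.
CONTROL_DT = 0.02   # environment (policy) step in seconds
PHYSICS_DT = 0.005  # model.opt.timestep in seconds

MJCF_OPTION = f"""
<option timestep="{PHYSICS_DT}" integrator="Euler" iterations="1" ls_iterations="5">
  <flag eulerdamp="disable"/>
</option>
"""


def physics_substeps(control_dt: float, physics_dt: float) -> int:
    """Number of physics steps (n_frames) run per control step."""
    ratio = control_dt / physics_dt
    n = round(ratio)
    if abs(ratio - n) > 1e-9:
        raise ValueError(f"control dt {control_dt} is not a multiple of timestep {physics_dt}")
    return n


n_frames = physics_substeps(CONTROL_DT, PHYSICS_DT)  # 4 physics steps per control step
```

Using round() with a tolerance check, rather than int(), guards against a ratio like 3.9999999 being truncated to 3.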

I am using MuJoCo/MJX 3.1.6 and Brax 0.9.4 (I also tried Brax 0.10.5 and saw the same issues).

Is there a reason that I am encountering this behaviour when performing the inference?

Thanks!

erikfrey commented 1 month ago

Just so I understand - you're training in MJX and evaluating the policy in C MuJoCo (presumably via the python bindings) and seeing unstable physics? Is it possible that there's something else different between the training and eval environments, possibly the initial state? Are you hitting some terminating condition that you're ignoring during the eval? What does the video look like leading up to the instability?

Feel free to post a colab.

willthibault commented 1 month ago

I am training in MJX and then evaluating the policy in Python the same way as in this part of the Barkour colab. As in the Barkour colab, the training and eval environments are identical, including the initial state. I monitor the termination condition when visualizing the policy, and it is not terminating; it simply produces the NaN control value. Up to the instability (generally one frame), the initial state from the keyframe I'm using is set and looks correct; then the control values go to NaN.
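Since the NaN appears within about one frame, a rollout loop that checks every control and observation for non-finite values and stops at the first bad step can show whether the policy or the physics step goes bad first. This is a sketch loosely modeled on the evaluation loop described above (policy call, then env step); policy_fn and step_fn stand in for the jitted inference and step functions, and ToyState is a hypothetical stand-in for the Brax State:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np


class ToyState(NamedTuple):
    """Hypothetical stand-in for a Brax State; only `obs` is used here."""
    obs: jax.Array


def all_finite(x) -> bool:
    """True if every element of x is finite (no NaN or Inf)."""
    return bool(np.all(np.isfinite(np.asarray(x))))


def rollout_until_nonfinite(policy_fn, step_fn, state, rng, n_steps):
    """Roll out a policy and stop at the first non-finite control or observation.

    Returns (states, bad_step, source), where source is "ctrl", "obs", or None.
    """
    states = [state]
    for i in range(n_steps):
        act_rng, rng = jax.random.split(rng)
        ctrl = policy_fn(state.obs, act_rng)
        if not all_finite(ctrl):
            return states, i, "ctrl"  # the policy output went bad first
        state = step_fn(state, ctrl)
        if not all_finite(state.obs):
            return states, i, "obs"  # the environment step went bad first
        states.append(state)
    return states, None, None


# Toy usage: log(3 - obs) becomes -inf once obs reaches 3.
policy = jax.jit(lambda obs, key: jnp.log(3.0 - obs))
step = jax.jit(lambda s, ctrl: ToyState(obs=s.obs + 1.0))
states, bad_step, source = rollout_until_nonfinite(
    policy, step, ToyState(obs=jnp.array(0.0)), jax.random.PRNGKey(0), n_steps=10
)
```

With the real environment, the same check could also be applied to other fields of the state, not just the observation.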

Just to reiterate, I am using a pipeline that very closely mimics the Barkour colab. I have used this pipeline for many problems and am reasonably certain that it works. In the past, when I had NaNs, they occurred during training and were either caused by an unstable simulation or could be resolved with jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGH). In this case, that setting does not solve the issue, since the NaNs only occur when generating the frames for the video. I'm wondering if this happens when jitting the functions for inference.

I'll see if I can put together a colab to reproduce this issue, but it involves a fair amount of training time and I may not be able to open source this just yet (hopefully soon though). Are there any additional checks I could perform or logs I could provide?

erikfrey commented 1 month ago

OK, if you really think it's happening somewhere in the inference function, that's a bit surprising to me, but the good news is that it's a pretty small surface area to search: really only a few hundred lines of code or so. You can try removing the @jit so you can trace through, or binary search for the NaN with jax.debug.print. I bet you can find it that way. Let us know!
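A minimal sketch of the two approaches suggested here, using toy stages rather than the real inference function: check_finite is a hypothetical helper that uses jax.debug.print to report NaN/Inf from inside jitted code, and first_nonfinite_stage (also hypothetical) runs a sequence of stages under jax.disable_jit() and returns the first one whose output is non-finite:

```python
import jax
import jax.numpy as jnp


def check_finite(tag: str, x):
    """Print (also from inside jit) whether `x` has any NaN/Inf, and return x unchanged."""
    # `tag` goes into the format string because debug.print's arguments must be array-like.
    jax.debug.print(f"{tag}: non-finite = {{bad}}", bad=jnp.any(~jnp.isfinite(x)))
    return x


def first_nonfinite_stage(stages, x):
    """Run (name, fn) stages eagerly and return the name of the first non-finite output."""
    with jax.disable_jit():
        for name, fn in stages:
            x = fn(x)
            if not bool(jnp.all(jnp.isfinite(x))):
                return name
    return None


# Instrument a jitted function, then locate the failing stage eagerly.
jitted = jax.jit(lambda v: check_finite("after_log", jnp.log(check_finite("input", v))))
stages = [("scale", lambda v: v * 2.0), ("log", jnp.log), ("exp", jnp.exp)]
```

Moving check_finite calls halfway between the last clean print and the first bad one narrows the search quickly.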

willthibault commented 1 month ago

Thanks for the suggestion! I'll spend some time tracking down the error and share what I find.

i1Cps commented 1 month ago

any luck?

willthibault commented 1 month ago

@i1Cps,

I still need to investigate this some more, but I can share what I have figured out so far.

First, make sure that your simulation is stable. Simulations with many contacts, unrealistically high control actions, or highly constrained systems (e.g., the equality constraints creating a loop mentioned above) can easily become unstable. This was not the case for me.

What produced NaNs:

What did not produce NaNs:

I will work on tracking this down more in the coming weeks, but hopefully this helps!

erikfrey commented 1 month ago

I have also seen scenarios where a 4090 produces unstable physics where an A100 does not, given the exact same MJX environment and python version. I have yet to track down why, but it probably has something to do with matmul precision defaults.
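One way to check whether a given GPU's default float32 matmul runs at reduced precision (for example TF32 on recent NVIDIA GPUs) is to compare a default-precision product against one computed with Precision.HIGHEST. This is a generic sketch, not something from the MJX codebase; matmul_precision_gap is a hypothetical helper, and on CPU the gap is typically zero or close to it:

```python
import jax
import jax.numpy as jnp


def matmul_precision_gap(n: int = 512, seed: int = 0) -> float:
    """Max |default - HIGHEST| over an n x n float32 matmul of random normal matrices."""
    key_a, key_b = jax.random.split(jax.random.PRNGKey(seed))
    a = jax.random.normal(key_a, (n, n), dtype=jnp.float32)
    b = jax.random.normal(key_b, (n, n), dtype=jnp.float32)
    default = jnp.matmul(a, b)  # follows jax_default_matmul_precision
    highest = jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)
    return float(jnp.max(jnp.abs(default - highest)))


print(jax.devices()[0].platform, matmul_precision_gap())
```

Running this on both the RTX 4090 and the A100 might show whether their default matmul precision differs; a clearly nonzero gap means the default is not full float32.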