Closed: chan-mander closed this issue 8 months ago
I have also used the JAX debug flags to see where the NaNs are occurring, and this was the output:
Traceback (most recent call last):
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/pjit.py", line 1148, in _pjit_call_impl_python
return compiled.unsafe_call(*args), compiled
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 1233, in __call__
dispatch.check_special(self.name, arrays)
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 436, in check_special
_check_special(name, buf.dtype, buf)
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 441, in _check_special
raise FloatingPointError(f"invalid value (nan) encountered in {name}")
FloatingPointError: invalid value (nan) encountered in jit(loss)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/pjit.py", line 1148, in _pjit_call_impl_python
return compiled.unsafe_call(*args), compiled
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 1233, in __call__
dispatch.check_special(self.name, arrays)
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 436, in check_special
_check_special(name, buf.dtype, buf)
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 441, in _check_special
raise FloatingPointError(f"invalid value (nan) encountered in {name}")
FloatingPointError: invalid value (nan) encountered in jit(loss)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/chandikasilva/WS/EAGER/scratch_tool/simple_test_example.py", line 218, in <module>
main()
File "/home/chandikasilva/WS/EAGER/scratch_tool/simple_test_example.py", line 213, in main
grad, _ = jax.jit(jax.grad(loss, has_aux=True))(weight, system, state)
File "/home/chandikasilva/WS/EAGER/scratch_tool/simple_test_example.py", line 170, in loss
state = pipeline.step(system, state, force)
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/brax/generalized/pipeline.py", line 79, in step
state = mass.matrix_inv(sys, state, sys.matrix_inv_iterations)
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/brax/generalized/mass.py", line 101, in matrix_inv
mx_inv = math.inv_approximate(mx, mx_inv, num_iter)
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/brax/math.py", line 313, in inv_approximate
(a_inv, _, _), _ = jax.lax.scan(body_fn, (a_inv, r0, 1.0), None, num_iter)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: FloatingPointError: invalid value (nan) encountered in jit(scan)
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/chandikasilva/WS/EAGER/scratch_tool/simple_test_example.py", line 218, in <module>
main()
File "/home/chandikasilva/WS/EAGER/scratch_tool/simple_test_example.py", line 213, in main
grad, _ = jax.jit(jax.grad(loss, has_aux=True))(weight, system, state)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/pjit.py", line 253, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/pjit.py", line 166, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 2596, in bind
return self.bind_with_trace(top_trace, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 389, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 821, in process_primitive
return primitive.impl(*tracers, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/pjit.py", line 1209, in _pjit_call_impl
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/pjit.py", line 1192, in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/pjit.py", line 1152, in _pjit_call_impl_python
_ = core.jaxpr_as_fun(jaxpr)(*args) # may raise, not return
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 235, in jaxpr_as_fun
return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 454, in eval_jaxpr
ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 2596, in bind
return self.bind_with_trace(top_trace, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 389, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 821, in process_primitive
return primitive.impl(*tracers, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/pjit.py", line 1209, in _pjit_call_impl
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/pjit.py", line 1192, in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/pjit.py", line 1152, in _pjit_call_impl_python
_ = core.jaxpr_as_fun(jaxpr)(*args) # may raise, not return
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 235, in jaxpr_as_fun
return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 454, in eval_jaxpr
ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py", line 1097, in scan_bind
return core.AxisPrimitive.bind(scan_p, *args, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 2596, in bind
return self.bind_with_trace(top_trace, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 389, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 821, in process_primitive
return primitive.impl(*tracers, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 143, in apply_primitive
return compiled_fun(*args)
^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 1233, in __call__
dispatch.check_special(self.name, arrays)
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 436, in check_special
_check_special(name, buf.dtype, buf)
File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 441, in _check_special
raise FloatingPointError(f"invalid value (nan) encountered in {name}")
jax._src.traceback_util.UnfilteredStackTrace: FloatingPointError: invalid value (nan) encountered in jit(scan)
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/chandikasilva/WS/EAGER/scratch_tool/simple_test_example.py", line 218, in <module>
main()
File "/home/chandikasilva/WS/EAGER/scratch_tool/simple_test_example.py", line 213, in main
grad, _ = jax.jit(jax.grad(loss, has_aux=True))(weight, system, state)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
FloatingPointError: invalid value (nan) encountered in jit(scan)
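For anyone trying to get a similar trace: a `FloatingPointError` like the one above is what JAX raises when its NaN checking is switched on. Here is a minimal sketch, assuming a toy function (`toy_loss` and `nan_detected` are hypothetical names, not part of Brax or the original script) in place of the real loss:

```python
import jax
import jax.numpy as jnp

# Turn on NaN checking: JAX raises FloatingPointError when a computation yields NaN.
jax.config.update("jax_debug_nans", True)


def toy_loss(x):
    # Hypothetical stand-in for the real loss: 0 / 0 produces NaN when x == 0.
    return jnp.sum(x / x)


def nan_detected(fn, *args):
    """Return True if calling fn raises FloatingPointError under jax_debug_nans."""
    try:
        fn(*args).block_until_ready()
    except FloatingPointError:
        return True
    return False


caught = nan_detected(jax.jit(toy_loss), jnp.zeros(3))  # input that yields NaN
clean = nan_detected(jax.jit(toy_loss), jnp.ones(3))    # input that stays finite

# Switch the flag back off so later code is not affected.
jax.config.update("jax_debug_nans", False)
```

The flag adds overhead because JAX re-runs a failing jitted function op by op to locate the NaN, so it is probably best left off outside of debugging.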
Thanks for the report @chan-mander, this should be fixed in 1630403, although the fix results in a slight performance hit for generalized.
Hello,
Thank you for your amazing work on Brax. I am excited to use this physics engine for simulations.
I have been running into a particular issue that I don't understand. Below is a simple example that reproduces the issue I am seeing.
When I use `brax.generalized` and try to get gradients from my loss function, I end up getting NaNs, but the box object's final position is where I expect it to be when friction and a constant force are applied to it. When I use `brax.positional`, I am able to get gradients, but the final position of the box object is not where it should be. Ideally, I would like to get gradients using `brax.generalized`. Any help or insight as to why I am seeing this issue would be very helpful.
Thank you.
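As a generic illustration of how a forward result can look correct while the gradient is NaN, here is a small sketch. It is not the Brax loss and makes no claim about the actual cause in `brax.generalized`; `toy_loss` and `has_nan` are hypothetical names, and only the `jax.grad(..., has_aux=True)` call pattern is taken from the trace.

```python
import jax
import jax.numpy as jnp


def toy_loss(w):
    # Hypothetical loss: the value at w == 0 is finite (0.0), but the derivative
    # of sqrt at 0 is infinite, and inf * 0 in the chain rule gives NaN.
    dist = jnp.sqrt(jnp.sum(w ** 2))
    return dist, dist  # (loss, aux), as expected by has_aux=True


def has_nan(tree):
    """Return True if any array leaf of a pytree contains a NaN."""
    leaves = jax.tree_util.tree_leaves(tree)
    return any(bool(jnp.isnan(leaf).any()) for leaf in leaves)


w = jnp.zeros(3)
value, _ = toy_loss(w)                           # forward pass only
grad, _ = jax.grad(toy_loss, has_aux=True)(w)    # gradient with respect to w

forward_ok = not has_nan(value)  # the forward value is finite
grad_bad = has_nan(grad)         # the gradient contains NaN
```

A check like `has_nan` can be run on the gradient pytree after each step to see when NaNs first appear, without turning on the global debug flag.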