google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Getting gradients using the generalized pipeline results in NaNs #412

Closed. chan-mander closed this issue 8 months ago

chan-mander commented 8 months ago

Hello,

Thank you for your amazing work on brax; I am excited to use this physics engine for my simulations.

I have been running into an issue that I don't understand. Below is a simple example that reproduces it.

import jax
import jax.numpy as jnp
from brax.io import mjcf
import brax
from brax.generalized import pipeline
# from brax.positional import pipeline

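@jax.jit
def loss(weight, system, state):
    # Linear feedback on the box's x position and the rootx joint velocity.
    x = jnp.array([state.x.pos[0][0], state.qd[0]])
    force = weight @ x.transpose()
    # Take one physics step, then penalize the squared x position.
    state = pipeline.step(system, state, force)
    loss_value = state.x.pos[0][0]**2
    return loss_value, state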

def main():
    key = jax.random.PRNGKey(seed=0)
    key, subkey = jax.random.split(key, 2)
    weight = jax.random.normal(subkey, (1, 2))

    system: brax.System = mjcf.loads(
                                """<mujoco model="ice skate">
                                    <option timestep="0.001"/>
                                    <worldbody>
                                        <body name="box" pos="10 0 0.5">
                                            <joint limited="false" type="slide" axis="1 0 0" name="rootx"/> 
                                            <joint limited="false" type="slide" axis="0 0 1" name="rootz"/> 
                                            <geom name="box_geom" type="box" size="0.5 0.5 0.5" mass="1" condim="3" friction="0.3 0"/>
                                        </body>
                                        <geom name="floor" pos="0 0 0" size="5 5 5" type="plane" condim="3" friction="0.3 0"/>
                                    </worldbody>                                
                                    <actuator>
                                        <motor ctrllimited="true" ctrlrange="-100 100" gear="1" joint="rootx"/>
                                    </actuator>
                                    </mujoco>""")

    qd = jnp.zeros(system.qd_size())
    state = jax.jit(pipeline.init)(system, system.init_q, qd)

    simTime = 10
    timeSteps = int(simTime / 0.001)

    print("_________________________________________")
    print("Initial State Pos: ", state.x.pos[0][0])
    # Compile the step function once, outside the loop.
    step_fn = jax.jit(pipeline.step)
    force = jnp.array([-10.0])  # constant actuator force on rootx
    for i in range(timeSteps):
        state = step_fn(system, state, force)

    print("Final State Pos: ", state.x.pos[0][0])
    print("_________________________________________")

    grad, _ = jax.jit(jax.grad(loss, has_aux=True))(weight, system, state)
    print("Gradients of layer 0 weights:\n", grad)

if __name__ == "__main__":
    main()

When I use brax.generalized and compute gradients of my loss function, I get NaNs, even though the box's final position is where I expect it to be given friction and a constant applied force. When I use brax.positional, I do get gradients, but the box's final position is not where it should be. Ideally, I would like to get gradients using brax.generalized. Any help or insight into why I am seeing this would be appreciated.

Thank you.
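To tell the two pipelines apart programmatically, a small helper like the hypothetical `has_nan` below can check a gradient pytree (such as the `grad` returned by `jax.grad(loss, has_aux=True)` in the script above) for NaN entries. It is a sketch that relies only on `jax.tree_util.tree_leaves` and `jnp.isnan`; it is not part of brax.

```python
import jax
import jax.numpy as jnp


def has_nan(tree):
    """Hypothetical helper: True if any array leaf in a pytree contains NaN."""
    leaves = jax.tree_util.tree_leaves(tree)
    return any(bool(jnp.any(jnp.isnan(leaf))) for leaf in leaves)


if __name__ == "__main__":
    print(has_nan({"w": jnp.array([[0.1, jnp.nan]])}))
    print(has_nan({"w": jnp.ones((1, 2))}))
```

A bare array is itself a pytree leaf, so the same helper works on the `(1, 2)` weight gradient directly.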

chan-mander commented 8 months ago

I also enabled the JAX debug flags to see where the NaNs occur, and this was the output:

Traceback (most recent call last):
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/pjit.py", line 1148, in _pjit_call_impl_python
    return compiled.unsafe_call(*args), compiled
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 1233, in __call__
    dispatch.check_special(self.name, arrays)
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 436, in check_special
    _check_special(name, buf.dtype, buf)
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 441, in _check_special
    raise FloatingPointError(f"invalid value (nan) encountered in {name}")
FloatingPointError: invalid value (nan) encountered in jit(loss)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/pjit.py", line 1148, in _pjit_call_impl_python
    return compiled.unsafe_call(*args), compiled
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 1233, in __call__
    dispatch.check_special(self.name, arrays)
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 436, in check_special
    _check_special(name, buf.dtype, buf)
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 441, in _check_special
    raise FloatingPointError(f"invalid value (nan) encountered in {name}")
FloatingPointError: invalid value (nan) encountered in jit(loss)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/chandikasilva/WS/EAGER/scratch_tool/simple_test_example.py", line 218, in <module>
    main()
  File "/home/chandikasilva/WS/EAGER/scratch_tool/simple_test_example.py", line 213, in main
    grad, _ = jax.jit(jax.grad(loss, has_aux=True))(weight, system, state)
  File "/home/chandikasilva/WS/EAGER/scratch_tool/simple_test_example.py", line 170, in loss
    state = pipeline.step(system, state, force)
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/brax/generalized/pipeline.py", line 79, in step
    state = mass.matrix_inv(sys, state, sys.matrix_inv_iterations)
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/brax/generalized/mass.py", line 101, in matrix_inv
    mx_inv = math.inv_approximate(mx, mx_inv, num_iter)
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/brax/math.py", line 313, in inv_approximate
    (a_inv, _, _), _ = jax.lax.scan(body_fn, (a_inv, r0, 1.0), None, num_iter)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: FloatingPointError: invalid value (nan) encountered in jit(scan)

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/chandikasilva/WS/EAGER/scratch_tool/simple_test_example.py", line 218, in <module>
    main()
  File "/home/chandikasilva/WS/EAGER/scratch_tool/simple_test_example.py", line 213, in main
    grad, _ = jax.jit(jax.grad(loss, has_aux=True))(weight, system, state)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/pjit.py", line 253, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
                                                 ^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/pjit.py", line 166, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 2596, in bind
    return self.bind_with_trace(top_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 389, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 821, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/pjit.py", line 1209, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/pjit.py", line 1192, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
                         ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/pjit.py", line 1152, in _pjit_call_impl_python
    _ = core.jaxpr_as_fun(jaxpr)(*args)  # may raise, not return
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 235, in jaxpr_as_fun
    return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 454, in eval_jaxpr
    ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 2596, in bind
    return self.bind_with_trace(top_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 389, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 821, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/pjit.py", line 1209, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/pjit.py", line 1192, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
                         ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/pjit.py", line 1152, in _pjit_call_impl_python
    _ = core.jaxpr_as_fun(jaxpr)(*args)  # may raise, not return
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 235, in jaxpr_as_fun
    return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 454, in eval_jaxpr
    ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py", line 1097, in scan_bind
    return core.AxisPrimitive.bind(scan_p, *args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 2596, in bind
    return self.bind_with_trace(top_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 389, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/core.py", line 821, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 143, in apply_primitive
    return compiled_fun(*args)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 1233, in __call__
    dispatch.check_special(self.name, arrays)
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 436, in check_special
    _check_special(name, buf.dtype, buf)
  File "/home/chandikasilva/anaconda3/envs/envbrax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 441, in _check_special
    raise FloatingPointError(f"invalid value (nan) encountered in {name}")
jax._src.traceback_util.UnfilteredStackTrace: FloatingPointError: invalid value (nan) encountered in jit(scan)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/chandikasilva/WS/EAGER/scratch_tool/simple_test_example.py", line 218, in <module>
    main()
  File "/home/chandikasilva/WS/EAGER/scratch_tool/simple_test_example.py", line 213, in main
    grad, _ = jax.jit(jax.grad(loss, has_aux=True))(weight, system, state)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
FloatingPointError: invalid value (nan) encountered in jit(scan)
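The FloatingPointError above is what JAX raises when its NaN check is enabled. For reference, a minimal sketch of turning that check on and off with jax.config.update("jax_debug_nans", True); the nan_checking_demo function is only an illustration and was not part of the original report.

```python
import jax
import jax.numpy as jnp


def nan_checking_demo():
    """Show that jax_debug_nans turns a silent NaN into a FloatingPointError."""
    jax.config.update("jax_debug_nans", True)
    try:
        # log of a negative number is NaN; with the flag on, JAX raises instead.
        jax.jit(jnp.log)(jnp.array(-1.0))
    except FloatingPointError as err:
        return str(err)
    finally:
        # Restore the default so later code is not affected.
        jax.config.update("jax_debug_nans", False)
    return None


if __name__ == "__main__":
    print(nan_checking_demo())
```

The check adds overhead, so it is usually enabled only while debugging.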

btaba commented 8 months ago

Thanks for the report, @chan-mander. This should be fixed in 1630403, although the fix results in a slight performance hit for the generalized pipeline.
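The traceback above ends inside brax.math.inv_approximate, which runs an iterative approximate matrix inverse under jax.lax.scan. The thread does not say what produced the NaNs or what commit 1630403 changed, so the following is only a general illustration of one well-known way a forward pass can be finite while its reverse-mode gradient is NaN: differentiating a Euclidean norm at exactly zero. naive_norm and safe_norm are hypothetical helpers, not brax functions, and the "double where" workaround is a generic JAX technique, not necessarily what the brax fix does.

```python
import jax
import jax.numpy as jnp


def naive_norm(x):
    # Euclidean norm; its derivative x / ||x|| is 0/0 at x = 0.
    return jnp.linalg.norm(x)


def safe_norm(x, eps=1e-12):
    # "Double where" trick: feed sqrt a harmless value when x is (near) zero,
    # so neither branch produces inf or NaN in the backward pass.
    sq = jnp.sum(x * x)
    is_small = sq < eps
    safe_sq = jnp.where(is_small, 1.0, sq)
    return jnp.where(is_small, 0.0, jnp.sqrt(safe_sq))


if __name__ == "__main__":
    zero = jnp.zeros(2)
    print(jax.grad(naive_norm)(zero))
    print(jax.grad(safe_norm)(zero))
```

Both functions return the same forward value; only the backward pass differs, which is why this kind of bug can leave the simulated trajectory correct while gradients come out as NaN.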