google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Gradient of generalized pipeline is nan #387

Closed: hackertyper closed this issue 8 months ago

hackertyper commented 10 months ago

Hi,

First of all, thanks for your work on Brax!

I have noticed an anomaly when trying to get gradients of different pipelines.

When differentiating through the pipeline functions, only the positional and spring backends yield a finite number as the gradient. The generalized backend only yields NaN.

The following code reproduces the issue. It simulates a free-falling ball in an otherwise empty environment for 100 steps. The resulting gradients are 1.0 for both the spring and positional pipelines, but NaN for the generalized pipeline.

```python
import jax
from jax import numpy as jp
from brax.generalized import pipeline as generalized_pipeline
from brax.positional import pipeline as positional_pipeline
from brax.spring import pipeline as spring_pipeline
from brax.io import mjcf

free_fall_xml = """
<mujoco>
  <option gravity="0 0 -9.81" timestep="0.005" density="1.2" viscosity="0.00002" />
  <worldbody>
    <body pos="0 0 3" name="freeball">
      <joint type="free" name="j"/>
      <geom type="sphere" size=".2" mass="1.0" name="g" />
    </body>
  </worldbody>
</mujoco>
"""


# Simulate physics for 100 steps. The function takes an initial x position and
# returns the final x position, so it can be differentiated with respect to x.
def simulation(pipeline, system, init_x):
    init_q = system.init_q.at[0].set(init_x)  # q[0] is the x coordinate of the free joint
    init_qd = jp.zeros(6)  # the ball starts at rest
    init_fn = jax.jit(pipeline.init)
    step_fn = jax.jit(pipeline.step)
    state = init_fn(system, init_q, init_qd)
    for _ in range(100):
        state = step_fn(system, state, None)  # the system has no actuators, so no action is passed
    return state.q[0]


if __name__ == "__main__":
    system = mjcf.loads(free_fall_xml)
    for pipeline in [generalized_pipeline, positional_pipeline, spring_pipeline]:
        print(f"{pipeline} results:")
        x = simulation(pipeline, system, 0.0)
        grad_x = jax.grad(simulation, argnums=2)(pipeline, system, 0.0)
        print(f"{x=}, {grad_x=}")
```

Output:

```
<module 'brax.generalized.pipeline'> results:
x=Array(0., dtype=float32), grad_x=Array(nan, dtype=float32, weak_type=True)
<module 'brax.positional.pipeline'> results:
x=Array(0., dtype=float32), grad_x=Array(1., dtype=float32, weak_type=True)
<module 'brax.spring.pipeline'> results:
x=Array(0., dtype=float32), grad_x=Array(1., dtype=float32, weak_type=True)
```

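
For reference, 1.0 is the expected value. With gravity acting only along z and the ball starting at rest, nothing pushes the ball sideways, so its final x equals its initial x and the derivative is exactly 1. The sketch below checks this with a hypothetical toy integrator in plain JAX (point mass, semi-implicit Euler, no drag, no contacts). `toy_free_fall_x` and its constants are illustrative assumptions, not Brax code, and the toy ignores the `density`/`viscosity` settings in the XML. It rolls out the steps with `jax.lax.scan`, which traces one step body instead of dispatching 100 separate jitted calls as a Python loop does.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model: a point mass under gravity only (no drag, no contacts).
DT = 0.005
GRAVITY = jnp.array([0.0, 0.0, -9.81])
NUM_STEPS = 100


def toy_free_fall_x(init_x):
    """Return the final x position after NUM_STEPS of semi-implicit Euler."""
    pos = jnp.array([0.0, 0.0, 3.0]).at[0].set(init_x)
    vel = jnp.zeros(3)  # starts at rest, like init_qd in the report

    def step(carry, _):
        pos, vel = carry
        vel = vel + DT * GRAVITY  # update velocity first ...
        pos = pos + DT * vel      # ... then position (semi-implicit Euler)
        return (pos, vel), None

    (pos, _), _ = jax.lax.scan(step, (pos, vel), None, length=NUM_STEPS)
    return pos[0]


x = toy_free_fall_x(0.0)
grad_x = jax.grad(toy_free_fall_x)(0.0)
print(f"{x=}, {grad_x=}")  # grad_x is 1.0: x never changes under z-only gravity
```
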
btaba commented 8 months ago

Hi @hackertyper, thanks for the bug report! If you can pinpoint which part of the pipeline produces the NaN using the JAX debugging flags (https://jax.readthedocs.io/en/latest/debugging/flags.html), that would help us find where the gradient breaks down. So far, we haven't paid much attention to gradients in the generalized implementation.
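
The flags page linked above documents `jax_debug_nans`: with it enabled, JAX raises a `FloatingPointError` when an operation produces a NaN, including operations in the backward pass of `jax.grad`, so the traceback can show where the NaN first appears. A minimal sketch of that workflow follows, using a hypothetical stand-in function (`speed`, a hand-written Euclidean norm) whose gradient v/|v| is 0/0 at the zero vector. This is a common way NaN gradients arise in physics code, but it is an illustration only, not a diagnosis of the generalized pipeline; to apply it to the report, one would wrap the `jax.grad(simulation, argnums=2)(...)` call the same way.

```python
import jax
import jax.numpy as jnp


def speed(v):
    # Hypothetical stand-in: hand-written Euclidean norm.
    # Its gradient is v / |v|, which is 0/0 at v = 0.
    return jnp.sqrt(jnp.sum(v * v))


def grad_with_nan_check(fn, x):
    """Compute jax.grad(fn)(x) with JAX's NaN checker enabled.

    Returns (gradient, None) on success, or (None, error message) if any
    operation, forward or backward, produced a NaN.
    """
    jax.config.update("jax_debug_nans", True)
    try:
        return jax.grad(fn)(x), None
    except FloatingPointError as err:
        return None, str(err)
    finally:
        # The flag is global, so restore the default afterwards.
        jax.config.update("jax_debug_nans", False)


v0 = jnp.zeros(3)
print(jax.grad(speed)(v0))             # [nan nan nan], with no error raised
print(grad_with_nan_check(speed, v0))  # with the checker on, the NaN surfaces as FloatingPointError
```
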

btaba commented 8 months ago

This should be fixed in 1630403.
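
For anyone hitting the same symptom in their own JAX code: a standard way to keep such a gradient finite is the "double where" pattern described in the JAX FAQ, sketched below for the same kind of norm. `safe_norm` and its `eps` threshold are hypothetical, and this is a generic JAX technique, not a description of what commit 1630403 changed.

```python
import jax
import jax.numpy as jnp


def safe_norm(v, eps=1e-12):
    """Hypothetical Euclidean norm whose gradient is 0 (not NaN) at v = 0."""
    sq = jnp.sum(v * v)
    is_small = sq < eps
    # Inner where: feed sqrt a harmless dummy input when sq is (near) zero,
    # so its derivative stays finite even though this branch is discarded.
    safe_sq = jnp.where(is_small, 1.0, sq)
    # Outer where: return the actual value, 0.0 near the origin.
    return jnp.where(is_small, 0.0, jnp.sqrt(safe_sq))


print(jax.grad(safe_norm)(jnp.zeros(3)))                 # all zeros instead of NaN
print(jax.grad(safe_norm)(jnp.array([3.0, 4.0, 0.0])))  # approximately v / |v| = [0.6, 0.8, 0.0]
```

The inner `jnp.where` is what matters: with only the outer one, the unused `sqrt` branch still receives a zero cotangent that is multiplied by an infinite derivative, and 0 * inf = NaN. Treating any squared norm below `eps` as zero is a design choice that trades a tiny bias near the origin for finite gradients.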