patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Getting NaN for gradients when computing the loss with jit compile #523

Closed: chan-mander closed this issue 11 months ago

chan-mander commented 12 months ago

Hello,

I am fairly new to JAX and Equinox, and I am running into an issue that I don't quite understand. I have the following loss function:

import jax
import jax.numpy as jnp
import equinox as eqx
import brax
from brax.generalized import pipeline
from functools import partial

@partial(jax.jit, static_argnames=('timeSteps',))
def loss_fn(model: eqx.Module, system: brax.System, state: brax.State, timeSteps: int = 2000):

    def step(i: int, carry: tuple):
        system_s, state_s, model_s, positions_s, velocities_s, forces_s = carry

        # Record the current position and velocity at step i.
        positions_s = positions_s.at[i, :].add(state_s.x.pos[0])
        velocities_s = velocities_s.at[i, :].add(state_s.qd)

        x = jnp.array([state_s.x.pos[0][0], state_s.qd[0]])
        force = model_s(x.transpose())
        forces_s = forces_s.at[i].add(force)

        # Advance the Brax simulation by one step using the model's force.
        state_s = pipeline.step(system_s, state_s, force)

        return (system_s, state_s, model_s, positions_s, velocities_s, forces_s)

    positions = jnp.zeros((timeSteps, 3), dtype=jnp.float32)
    velocities = jnp.zeros((timeSteps, 3), dtype=jnp.float32)
    forces = jnp.zeros((timeSteps, 1), dtype=jnp.float32)

    carry = (system, state, model, positions, velocities, forces)

    system, state, _, positions, velocities, forces = jax.lax.fori_loop(0, timeSteps, step, carry)

    # Targets have shape (timeSteps,) to match positions[:, 0]; a (timeSteps, 1)
    # target would broadcast the difference to shape (timeSteps, timeSteps).
    targetposition = jnp.zeros((timeSteps,))
    targetVelocities = jnp.zeros((timeSteps,))

    loss_value = jnp.sum(1.0*jnp.abs(positions[:, 0] - targetposition) + 1.0*jnp.abs(velocities[:, 0] - targetVelocities)) / timeSteps

    return loss_value, (positions, velocities, forces, state)
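For comparison, a rollout like the one above can also be written with `jax.lax.scan`, which stacks each step's outputs automatically instead of writing them into preallocated arrays with `.at[i].add(...)`. This is a minimal sketch, not the code from this thread: Brax is replaced by a hypothetical explicit-Euler point mass (`toy_step`) and the Equinox model by a hypothetical linear controller (`controller`), so it runs without Brax.

```python
import jax
import jax.numpy as jnp


def toy_step(pos, vel, force, dt=0.001):
    # Hypothetical stand-in for pipeline.step: explicit Euler on a unit mass.
    vel = vel + dt * force
    pos = pos + dt * vel
    return pos, vel


def controller(params, pos, vel):
    # Hypothetical linear controller standing in for the Equinox model.
    return params @ jnp.array([pos, vel])


def rollout_loss(params, pos0, vel0, time_steps=2000):
    def step(carry, _):
        pos, vel = carry
        force = controller(params, pos, vel)
        pos, vel = toy_step(pos, vel, force)
        # scan stacks these per-step outputs into arrays of shape (time_steps,).
        return (pos, vel), (pos, vel, force)

    init = (jnp.asarray(pos0, dtype=jnp.float32), jnp.asarray(vel0, dtype=jnp.float32))
    _, (positions, velocities, forces) = jax.lax.scan(step, init, xs=None, length=time_steps)
    # Targets are zero here, so this is the mean of |position| + |velocity|.
    loss = jnp.mean(jnp.abs(positions) + jnp.abs(velocities))
    return loss, (positions, velocities, forces)


loss_and_grad = jax.jit(jax.value_and_grad(rollout_loss, has_aux=True))
(loss, (positions, velocities, forces)), grads = loss_and_grad(jnp.array([-1.0, -0.5]), 1.0, 0.0)
print(loss, grads)
```

`time_steps` is left at its default here; because `length` must be a concrete Python int, it would need to be marked static if it were passed through `jax.jit`.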

I try to get the sensitivity of the loss with respect to the model parameters by running the following line:

l, grad = eqx.filter_value_and_grad(loss_fn, has_aux=True)(model, system, newState, timeSteps)

When I run this with JAX JIT compilation I get NaN for the gradients, but when I disable JIT compilation I get finite values for the gradients. Can anyone help shed some light on why this is happening, and on how I can compute gradients with JIT enabled?

lockwo commented 12 months ago

It might be helpful to have a minimum reproducible example. I see some brax usage as well so giving a short snippet of code that can be copied and pasted into colab and work to show the error is helpful when debugging.

Also, as a minor note, you can get better syntax highlighting and formatting by wrapping code in a fenced code block: a line with ```python, then the code, then a line with ```.

E.g.

@partial(jax.jit, static_argnames=('timeSteps'))
def loss_fn(model: eqx.Module, system: brax.System, state: brax.State, timeSteps: int = 2000):

  def step(i:int, carry: tuple):
      system_s, state_s, model_s, positions_s, velocities_s, forces_s = carry

      positions_s = positions_s.at[i,:].add(state_s.x.pos[0])
      velocities_s = velocities_s.at[i,:].add(state_s.qd)

      x = jnp.array([state_s.x.pos[0][0], state_s.qd[0]])
      force = model_s(x.transpose())
      forces_s = forces_s.at[i].add(force)
      state_s = pipeline.step(system_s, state_s, force)

      return (system_s, state_s, model_s, positions_s, velocities_s, forces_s)

  positions = jnp.zeros((timeSteps, 3), dtype=jnp.float32)
  velocities = jnp.zeros((timeSteps, 3), dtype=jnp.float32)
  forces = jnp.zeros((timeSteps, 1), dtype=jnp.float32)

  carry = (system, state, model, positions, velocities, forces)
  system, state, _, positions, velocities, forces = jax.lax.fori_loop(0,timeSteps,step,carry)
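  # Shape (timeSteps,) matches positions[:,0]; (timeSteps, 1) would broadcast the difference to (timeSteps, timeSteps).
  targetposition = jnp.zeros((timeSteps,))
  targetVelocities = jnp.zeros((timeSteps,))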
  loss_value = jnp.sum(1.0*jnp.abs(positions[:,0]-targetposition) + 1.0*jnp.abs(velocities[:,0]-targetVelocities)) / timeSteps

  return loss_value, (positions, velocities, forces, state)

patrick-kidger commented 12 months ago

See also https://docs.kidger.site/equinox/api/debug for advice on common sources of NaNs and some of the tools JAX provides to debug them.
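One well-known source of NaN gradients, described in the JAX FAQ, is `jnp.where` guarding a branch that is undefined for some inputs: the untaken branch still contributes `0 * nan = nan` to the gradient. The usual fix is a second `jnp.where` that keeps unsafe values away from the function. Setting `jax.config.update("jax_debug_nans", True)` is another way to find where a NaN first appears. Whether this pattern is what affected the Brax rollout in this thread is not established; the sketch below, with hypothetical `naive_sqrt`/`safe_sqrt`, only illustrates it.

```python
import jax
import jax.numpy as jnp


def naive_sqrt(x):
    # Forward value is fine, but the untaken sqrt branch makes the gradient NaN for x < 0.
    return jnp.where(x > 0, jnp.sqrt(x), 0.0)


def safe_sqrt(x):
    # "Double where": feed sqrt a harmless value wherever its result is discarded.
    x_safe = jnp.where(x > 0, x, 1.0)
    return jnp.where(x > 0, jnp.sqrt(x_safe), 0.0)


naive_grad = jax.jit(jax.grad(naive_sqrt))
safe_grad = jax.jit(jax.grad(safe_sqrt))

print(naive_grad(-1.0))  # nan
print(safe_grad(-1.0))   # 0.0
print(safe_grad(4.0))    # 0.25
```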

chan-mander commented 11 months ago

Hello,

Thank you for the link. I am currently looking through the debug tools and trying to use them, but I am still having this issue. Here is a minimal reproducible example; any insights you may have would be helpful:

import jax
import jax.numpy as jnp
import brax
from brax.io import mjcf
from brax.generalized import pipeline
import equinox as eqx
import optax
from functools import partial

class SimpleModel(eqx.Module):
    layers: list          # list of the layers in the NN
    extra_bias: jax.Array # extra bias on the output

    def __init__(self, key):
        key1, key2, key3 = jax.random.split(key, 3)
        # layers contain trainable parameters
        self.layers = [
            eqx.nn.Linear(2, 4, key=key1),
            eqx.nn.Linear(4, 4, key=key2),
            eqx.nn.Linear(4, 1, key=key3)
        ]
        # extra_bias is also a trainable parameter
        self.extra_bias = jax.numpy.ones(1)

    @jax.jit
    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = jax.nn.elu(layer(x))
        return self.layers[-1](x) + self.extra_bias

@jax.jit
def loss(model: SimpleModel, system: brax.System, state: brax.State):

    goal_position_x = 0.0
    x = jnp.array([state.x.pos[0][0], state.qd[0]])
    force = model(x.transpose())
    state = pipeline.step(system, state, force)

    loss_value = 1.0*(state.x.pos[0][0] - goal_position_x)**2

    return loss_value, state

def main():

    key = jax.random.PRNGKey(seed=0)
    key, subkey = jax.random.split(key, 2)
    model = SimpleModel(subkey)

    lr = 0.001
    opt = optax.adam(lr)
    opt_state = opt.init(model)

    system: brax.System = mjcf.loads(
                                    """<mujoco>
                                            <option timestep="0.001"/>
                                            <worldbody>
                                                <body pos="{x} {y} {z}">
                                                    <joint name="slider" type="slide" axis="1 0 0" limited="false"/>
                                                    <geom size="0.5 0.5 0.5" type="box" mass="1"/>
                                                </body>
                                                <geom name="floor" size="40 40 40" type="plane"/>
                                            </worldbody>
                                            <actuator>
                                                <motor ctrllimited="true" ctrlrange="-10 10" gear="10"  joint="slider" name="slide"/>
                                            </actuator>
                                        </mujoco>""".format(x=1.0, y=0.0, z=0.5))

    qd = jnp.array([0.0])
    state = jax.jit(pipeline.init)(system, system.init_q, qd)

    for i in range(30):

        l, grad = eqx.filter_value_and_grad(loss, has_aux=True)(model, system, state)

        print("Batch: ", i+1, " Loss: ", l[0])
        print("Gradients of layer 0 weights:\n", grad.layers[0].weight)

        state = l[1]
        updates, opt_state = opt.update(grad, opt_state)
        model = optax.apply_updates(model, updates)

if __name__ == "__main__":
    main()

lockwo commented 11 months ago

There seemed to be some issue with using jax.jit instead of eqx.filter_jit. This works for me:

import jax
import jax.numpy as jnp
import brax
from brax.io import mjcf
from brax.generalized import pipeline
import equinox as eqx
import optax
from functools import partial

class SimpleModel(eqx.Module):
    layers: list          # list of the layers in the NN
    extra_bias: jax.Array # extra bias on the output

    def __init__(self, key):
        key1, key2, key3 = jax.random.split(key, 3)
        # layers contain trainable parameters
        self.layers = [
            eqx.nn.Linear(2, 4, key=key1),
            eqx.nn.Linear(4, 4, key=key2),
            eqx.nn.Linear(4, 1, key=key3)
        ]
        # extra_bias is also a trainable parameter
        self.extra_bias = jax.numpy.ones(1)

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = jax.nn.elu(layer(x))
        return self.layers[-1](x) + self.extra_bias

def loss(model: SimpleModel, system: brax.System, state: brax.State):

    goal_position_x = 0.0
    x = jnp.array([state.x.pos[0][0], state.qd[0]])
    force = model(x.transpose())
    state = pipeline.step(system, state, force)

    loss_value = 1.0*(state.x.pos[0][0] - goal_position_x)**2

    return loss_value, state

def main():

    key = jax.random.PRNGKey(seed=0)
    key, subkey = jax.random.split(key, 2)
    model = SimpleModel(subkey)

    lr = 0.001
    opt = optax.adam(lr)
    opt_state = opt.init(model)

    system: brax.System = mjcf.loads(
                                    """<mujoco>
                                            <option timestep="0.001"/>
                                            <worldbody>
                                                <body pos="{x} {y} {z}">
                                                    <joint name="slider" type="slide" axis="1 0 0" limited="false"/>
                                                    <geom size="0.5 0.5 0.5" type="box" mass="1"/>
                                                </body>
                                                <geom name="floor" size="40 40 40" type="plane"/>
                                            </worldbody>
                                            <actuator>
                                                <motor ctrllimited="true" ctrlrange="-10 10" gear="10"  joint="slider" name="slide"/>
                                            </actuator>
                                        </mujoco>""".format(x=1.0, y=0.0, z=0.5))

    qd = jnp.array([0.0])
    state = jax.jit(pipeline.init)(system, system.init_q, qd)

    for i in range(5):
        l, grad = eqx.filter_jit(eqx.filter_value_and_grad(loss, has_aux=True))(model, system, state)

        print("Batch: ", i+1, " Loss: ", l[0])
        print("Gradients of layer 0 weights:\n", grad.layers[0].weight)

        state = l[1]
        updates, opt_state = opt.update(grad, opt_state)
        model = optax.apply_updates(model, updates)

if __name__ == "__main__":
    main()

Output:

Batch:  1  Loss:  1.0000165
Gradients of layer 0 weights:
 [[-8.9743912e-07 -0.0000000e+00]
 [-3.2213477e-06 -0.0000000e+00]
 [-3.2706109e-07 -0.0000000e+00]
 [ 9.9661854e-08  0.0000000e+00]]
Batch:  2  Loss:  1.0000491
Gradients of layer 0 weights:
 [[-9.2726850e-07 -7.6194153e-09]
 [-3.2008952e-06 -2.6301928e-08]
 [-3.6686370e-07 -3.0145388e-09]
 [ 1.3357227e-07  1.0975706e-09]]
Batch:  3  Loss:  1.000098
Gradients of layer 0 weights:
 [[-9.5715563e-07 -1.5666025e-08]
 [-3.1808418e-06 -5.2061694e-08]
 [-4.0678665e-07 -6.6579871e-09]
 [ 1.6731423e-07  2.7384772e-09]]
Batch:  4  Loss:  1.0001631
Gradients of layer 0 weights:
 [[-9.8711382e-07 -2.4135012e-08]
 [-3.1612092e-06 -7.7291816e-08]
 [-4.4685981e-07 -1.0925758e-08]
 [ 2.0088858e-07  4.9117412e-09]]
Batch:  5  Loss:  1.0002439
Gradients of layer 0 weights:
 [[-1.0171552e-06 -3.3021848e-08]
 [-3.1420084e-06 -1.0200500e-07]
 [-4.8710831e-07 -1.5813924e-08]
 [ 2.3429754e-07  7.6064479e-09]]

patrick-kidger commented 11 months ago

Notably, one important thing @lockwo has done is to put the JIT outside the grad, not inside. This is much more efficient! Ideally, the opt.update and apply_updates lines should also go inside the JIT; see e.g. this RNN example and its make_step function.

That said, this should only affect efficiency, not correctness.

patrick-kidger commented 11 months ago

Okay, I'm not too sure what's going on, but here's a smaller MWE that doesn't use Equinox or Optax. I'd suggest raising this with the Brax authors.

# This gets NaN gradients. Removing the `@jax.jit` decorator on `loss`, or removing the aux output
# from `jax.grad`, results in non-NaN gradients.

import jax
import jax.numpy as jnp
from brax.io import mjcf
from brax.generalized import pipeline

@jax.jit
def loss(weight, system, state):
    x = jnp.array([state.x.pos[0][0], state.qd[0]])
    force = weight @ x.transpose()
    state = pipeline.step(system, state, force)
    loss_value = state.x.pos[0][0]**2
    return loss_value, state

def main():
    key = jax.random.PRNGKey(seed=0)
    key, subkey = jax.random.split(key, 2)
    weight = jax.random.normal(subkey, (1, 2))

    system = mjcf.loads(
        """<mujoco>
                <option timestep="0.001"/>
                <worldbody>
                    <body pos="{x} {y} {z}">
                        <joint name="slider" type="slide" axis="1 0 0" limited="false"/>
                        <geom size="0.5 0.5 0.5" type="box" mass="1"/>
                    </body>
                    <geom name="floor" size="40 40 40" type="plane"/>
                </worldbody>
                <actuator>
                    <motor ctrllimited="true" ctrlrange="-10 10" gear="10"  joint="slider" name="slide"/>
                </actuator>
            </mujoco>""".format(x=1.0, y=0.0, z=0.5))

    qd = jnp.array([0.0])
    state = jax.jit(pipeline.init)(system, system.init_q, qd)
    grad, _ = jax.jit(jax.grad(loss, has_aux=True))(weight, system, state)
    print("Gradients of layer 0 weights:\n", grad)

if __name__ == "__main__":
    main()
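For reference on the `has_aux` variant mentioned in the comment above: with `has_aux=True`, `jax.grad` expects the function to return `(value, aux)`, differentiates only `value`, and returns `(grad, aux)`. The toy sketch below (hypothetical `toy_loss`, unrelated to Brax) shows that, for an ordinary function, keeping or dropping the aux output gives the same gradient, which is why the NaN appearing only in one configuration suggests something specific to the Brax computation rather than to `has_aux` itself.

```python
import jax
import jax.numpy as jnp


def toy_loss(weight, x):
    # Returns (scalar loss, auxiliary data); only the loss is differentiated.
    force = weight @ x
    new_x = x + 0.001 * force
    return jnp.sum(new_x**2), new_x


weight = jnp.array([[0.3, -0.2], [0.1, 0.5]])
x = jnp.array([1.0, 0.0])

grad_with_aux, aux = jax.jit(jax.grad(toy_loss, has_aux=True))(weight, x)
grad_no_aux = jax.jit(jax.grad(lambda w, x: toy_loss(w, x)[0]))(weight, x)

print(bool(jnp.allclose(grad_with_aux, grad_no_aux)))  # True
```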

chan-mander commented 11 months ago

I was able to solve the issue. It seems the problem was in how I was initializing the state with Brax. Thank you for your help and input.