patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0
2.14k stars 142 forks source link

BatchNorm raises "TypeError: Expected a callable value, got inf" #234

Closed ciupakabra closed 1 year ago

ciupakabra commented 2 years ago

I'm getting weird errors with BatchNorm. One example is the code below, where ODE parameters are optimized and the drift is a neural network with some BatchNorm layers. The error thrown is TypeError: Expected a callable value, got inf when vmapping a loss function.

I wasn't able to reduce this to something without diffrax -- double vmaps or vmapping a scan doesn't raise exceptions. I'm guessing this is an equinox issue since without batchnorm (bn=False) there are no problems.

import optax
import jax
import jax.numpy as jnp
import jax.random as jrandom
import jax.nn as jnn
import equinox as eqx
import diffrax as dx
from tqdm import tqdm

def integrate(drift, num_steps, dim, key):

    def f(t, y, args):
        return drift(jnp.concatenate([t[None], y]))

    drift_term = dx.ODETerm(f)
    solver = dx.Euler()
    y0 = jnp.zeros(dim)

    ts = jnp.linspace(0, 1, num_steps + 1)
    saveat = dx.SaveAt(ts=ts)

    sol = dx.diffeqsolve(
        drift_term,
        solver,
        0,
        1,
        1/num_steps,
        y0,
        saveat=saveat,
        max_steps=num_steps + 1,
    )

    return sol.ys

def loss(drift, num_steps, dim, key):
    path = integrate(drift, num_steps, dim, key)
    final = path[-1]
    loss = jnp.sum(final**2)
    return loss

@eqx.filter_value_and_grad
def loss_mean(drift, num_steps, dim, key, batch_size):
    loss_vmapped = jax.vmap(loss, (None, None, None, 0), 0, axis_name="batch")
    key = jrandom.split(key, batch_size)
    return jnp.mean(loss_vmapped(drift, num_steps, dim, key))

class Network(eqx.Module):
    net: eqx.Module

    def __init__(self, in_size, out_size, width, depth, *, key, bn=True):

        keys = jrandom.split(key, depth + 1)
        layers = []
        if depth == 0:
            layers.append(eqx.nn.Linear(in_size, out_size, key=keys[0]))
        else:
            layers.append(eqx.nn.Linear(in_size, width, key=keys[0]))
            if bn:
                layers.append(eqx.experimental.BatchNorm(width, axis_name="batch"))
            for i in range(depth - 1):
                layers.append(eqx.nn.Linear(width, width, key=keys[i + 1]))
                if bn: 
                    layers.append(eqx.experimental.BatchNorm(width, axis_name="batch"))
                layers.append(eqx.nn.Lambda(jnn.relu))
            layers.append(eqx.nn.Linear(width, out_size, key=keys[-1]))

        self.net = eqx.nn.Sequential(layers)

    def __call__(self, x):
        return self.net(x)

if __name__=="__main__":

    key = jrandom.PRNGKey(0)

    init_drift_key, train_key = jrandom.split(key, 2)

    dim = 500

    drift = Network(dim + 1, dim, 300, 2, key=init_drift_key, bn=True)

    optimizer = optax.adamw(1e-4)
    opt_state = optimizer.init(eqx.filter(drift, eqx.is_inexact_array))

    @eqx.filter_jit
    def make_step(drift, num_steps, dim, key, batch_size, opt_state):
        loss, grads = loss_mean(drift, num_steps, dim, key, batch_size)
        updates, opt_state = optimizer.update(
            grads, opt_state, eqx.filter(drift, eqx.is_inexact_array)
        )
        drift = eqx.apply_updates(drift, updates)
        return loss, drift, opt_state

    for step in tqdm(range(100)):
        step_key = jrandom.fold_in(train_key, step)
        loss, drift, opt_state = make_step(
            drift, 10, dim, step_key, 32, opt_state
        )
(env) andrius:/home/andrius/repos/test% python test.py  
  1%|▏                                                                                                                                                                                                                       | 1/100 [00:01<03:14,  1.97s/it]
Traceback (most recent call last):
  File "/home/andrius/repos/test/test.py", line 99, in <module>
    loss, drift, opt_state = make_step(
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/jit.py", line 82, in __call__
    return __self._fun_wrapper(False, args, kwargs)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/jit.py", line 78, in _fun_wrapper
    dynamic_out, static_out = self._cached(dynamic, static)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/api.py", line 622, in cache_miss
    execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/dispatch.py", line 236, in _xla_call_impl_lazy
    return xla_callable(fun, device, backend, name, donated_invars, keep_unused,
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/linear_util.py", line 303, in memoized_fun
    ans = call(fun, *args)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/dispatch.py", line 359, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/dispatch.py", line 445, in lower_xla_callable
    jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2077, in trace_to_jaxpr_final2
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2027, in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/jit.py", line 30, in fun_wrapped
    out = fun(*args, **kwargs)
  File "/home/andrius/repos/test/test.py", line 89, in make_step
    loss, grads = loss_mean(drift, num_steps, dim, key, batch_size)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/grad.py", line 40, in __call__
    return fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/api.py", line 1167, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/api.py", line 2656, in _vjp
    out_primal, out_vjp = ad.vjp(
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/interpreters/ad.py", line 135, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/interpreters/ad.py", line 124, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 767, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/grad.py", line 37, in fun_value_and_grad
    return __self._fun(_x, *_args, **_kwargs)
  File "/home/andrius/repos/test/test.py", line 43, in loss_mean
    loss_vmapped = jax.vmap(loss, (None, None, None, 0), 0, axis_name="batch")
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/api.py", line 1647, in vmap
    _check_callable(fun)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/api.py", line 181, in _check_callable
    raise TypeError(f"Expected a callable value, got {fun}")
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Expected a callable value, got inf

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/andrius/repos/test/test.py", line 99, in <module>
    loss, drift, opt_state = make_step(
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/jit.py", line 82, in __call__
    return __self._fun_wrapper(False, args, kwargs)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/jit.py", line 78, in _fun_wrapper
    dynamic_out, static_out = self._cached(dynamic, static)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/jit.py", line 30, in fun_wrapped
    out = fun(*args, **kwargs)
  File "/home/andrius/repos/test/test.py", line 89, in make_step
    loss, grads = loss_mean(drift, num_steps, dim, key, batch_size)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/grad.py", line 40, in __call__
    return fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/grad.py", line 37, in fun_value_and_grad
    return __self._fun(_x, *_args, **_kwargs)
  File "/home/andrius/repos/test/test.py", line 43, in loss_mean
    loss_vmapped = jax.vmap(loss, (None, None, None, 0), 0, axis_name="batch")
TypeError: Expected a callable value, got inf
(env) andrius:/home/andrius/repos/test% pip list
Package           Version
----------------- ---------------------
absl-py           1.3.0
chex              0.1.5
diffrax           0.2.2
dm-tree           0.1.7
equinox           0.9.2
jax               0.3.25
jaxlib            0.3.25+cuda11.cudnn82
jaxtyping         0.2.8
numpy             1.23.5
opt-einsum        3.3.0
optax             0.1.4
pi                0.1.2
pip               22.3.1
scipy             1.9.3
setuptools        65.5.0
toolz             0.12.0
tqdm              4.64.1
typeguard         2.13.3
typing_extensions 4.4.0
wheel             0.37.1
uuirs commented 2 years ago

loss is redefined in here

loss, drift, opt_state = make_step( drift, 10, dim, step_key, 32, opt_state ) rename loss function to loss_fn, it works.

patrick-kidger commented 2 years ago

A @uuirs points out, you've redefined the variable loss from a function to an array.

That said, it's a little odd that you're only hitting this when using BatchNorm. Is BatchNorm triggering a recompilation, somehow?

ciupakabra commented 2 years ago

Yeah it's odd that this happened in such a way and you're right -- I checked and with bn=False it compiles once, whereas with bn=True it compiles twice.

ciupakabra commented 2 years ago

Had some time to debug this tonight. I think it's because the state of BatchNorm is first initialized with a pytree of a different PyTreeDef than what it stores once it actually gets a state, i.e. it's first initialized as an array:

https://github.com/patrick-kidger/equinox/blob/039f6bf69125196dd2357ecf7e082d2f1bcc422c/equinox/experimental/batch_norm.py#L115

but once it calculates the running means and vars it sets it to a tuple of arrays:

https://github.com/patrick-kidger/equinox/blob/039f6bf69125196dd2357ecf7e082d2f1bcc422c/equinox/experimental/batch_norm.py#L175

This changes the PyTreeDef of the StateIndex module, which then changes the hash of PyTreeDef of the whole model. And since the PyTreeDef is used as a static argument somewhere in equinox jitting logic, it gets recompiled.

I'm assuming the fix is a one line change in BatchNorm initialization?

patrick-kidger commented 2 years ago

Aren't these different state indices?

ciupakabra commented 2 years ago

Sorry, you're right, for initialization I meant to reference:

https://github.com/patrick-kidger/equinox/blob/039f6bf69125196dd2357ecf7e082d2f1bcc422c/equinox/experimental/stateful.py#L107

To be fair, even then I don't understand some of the behaviour. A small working example of the bug above is:

import jax
import jax.random as jrandom
import equinox as eqx

@eqx.filter_jit
def fun(bn, inp):
    print("Compiling!")
    return jax.vmap(bn, axis_name="batch")(inp)

def info(bn):
    print(f"bn.state_index._state: {bn.state_index._state}") # prints None all the time
    children, aux = bn.state_index.tree_flatten() # children should be dynamic_field_values and so children[0] should be the value of _state 
    print(f"children[0] of bn.state_index flatten: {children[0]}") # prints a tuple of arrays after 1st call to fun

if __name__=="__main__":
    bn = eqx.experimental.BatchNorm(10, axis_name="batch")
    inp = jrandom.normal(jrandom.PRNGKey(0), (32, 10))

    info(bn)
    fun(bn, inp)
    info(bn)
    fun(bn, inp)
    info(bn)
    fun(bn, inp)

which outputs

bn.state_index._state: None
children[0] of bn.state_index flatten: None
Compiling!
bn.state_index._state: None
children[0] of bn.state_index flatten: (DeviceArray([-0.32395738, -0.21207958, -0.31645954,  0.05969752,
             -0.11307174,  0.0944065 , -0.14875616, -0.05194131,
              0.10097986,  0.25392908], dtype=float32), DeviceArray([0.92587423, 0.95594984, 1.0194211 , 0.84475476, 0.76749337,
             0.77187854, 1.38814   , 0.8497227 , 1.1132355 , 0.86574566],            dtype=float32))
Compiling!
bn.state_index._state: None
children[0] of bn.state_index flatten: (DeviceArray([-0.32395738, -0.21207958, -0.31645954,  0.05969752,
             -0.11307174,  0.0944065 , -0.14875618, -0.05194131,
              0.10097986,  0.25392908], dtype=float32), DeviceArray([0.92587423, 0.95594984, 1.0194211 , 0.84475476, 0.76749337,
             0.77187854, 1.38814   , 0.8497227 , 1.1132355 , 0.86574566],            dtype=float32))

So you can clearly see that at least when accessing _state through tree_flatten of StateIndex, _state changes from None to a 2-tuple of arrays, changing the PyTreeDef of the static args.

I don't understand why bn.state_index._state is constantly outputing None though, I thought this would set it to a new value:

https://github.com/patrick-kidger/equinox/blob/039f6bf69125196dd2357ecf7e082d2f1bcc422c/equinox/experimental/stateful.py#L183

but I guess it gets deleted somewhere in between the two prints by _delete_smuggled_state.

patrick-kidger commented 1 year ago

Closing as eqx.experimental.BatchNorm is now available (in theory without bugs) as eqx.nn.BatchNorm.