patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Implementing Batch Normalization #45

Closed marcelroed closed 2 years ago

marcelroed commented 2 years ago

In Flax, Batch Normalization is a bit finicky since each call to apply requires marking batch_stats as mutable and updating the batch_stats afterward.

import flax.linen
import jax.numpy as jnp
from jax import random

# use_running_average=False: normalise with batch statistics and update batch_stats
bn = flax.linen.BatchNorm(use_running_average=False)

x = jnp.arange(24.0).reshape(3, 6)

vars = bn.init(random.PRNGKey(0), x)

# Mark the batch stats as mutable so we can update them in the variable dictionary
x_normed, mutated_vars = bn.apply(vars, x, mutable=['batch_stats'])

vars = {**vars, **mutated_vars}  # Update the variables with our diff

x_normed2, mutated_vars2 = bn.apply(vars, x, mutable=['batch_stats'])

How could this be implemented as a Module in Equinox? I'm happy to submit an implementation given some guidance.

patrick-kidger commented 2 years ago

Right. So this is something I've been avoiding implementing. It is possible using Equinox but I'm honestly not sure it belongs in the library itself, as it requires changing how you use things slightly.

First of all, an (untested) implementation looks something like this.

from typing import Optional

import equinox as eqx
import jax.lax as lax
import jax.numpy as jnp

def pvar(x, axis_name):
    # Variance across the named (e.g. vmapped) axis.
    mean = lax.pmean(x, axis_name)
    var = lax.pmean((x - mean)**2, axis_name)
    return var

class BatchNorm(eqx.Module):
    running_mean: Optional[jnp.ndarray]
    running_var: Optional[jnp.ndarray]
    momentum: float
    axis_name: str
    eps: float
    update_stats: bool

    def __init__(self, momentum=0.99, axis_name="batch", eps=1e-5, update_stats=True, **kwargs):
        super().__init__(**kwargs)
        self.running_mean = None
        self.running_var = None
        self.momentum = momentum
        self.axis_name = axis_name
        self.eps = eps
        self.update_stats = update_stats

    def __call__(self, x, update_stats=None):
        if update_stats is None:
            update_stats = self.update_stats
        # Batch statistics, averaged across the axis named `axis_name`.
        mean = lax.pmean(x, axis_name=self.axis_name)
        var = pvar(x, axis_name=self.axis_name)
        if self.running_mean is None:
            # First call: initialise the running statistics from this batch.
            running_mean = mean
            running_var = var
        else:
            running_mean = self.running_mean
            running_var = self.running_var
        running_mean = (1 - self.momentum) * mean + self.momentum * lax.stop_gradient(running_mean)
        running_var = (1 - self.momentum) * var + self.momentum * lax.stop_gradient(running_var)
        x = (x - running_mean) / jnp.sqrt(running_var + self.eps)
        if update_stats:
            # Bypass eqx.Module's immutability; see the explanation below.
            object.__setattr__(self, "running_mean", running_mean)
            object.__setattr__(self, "running_var", running_var)
        return x
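lax.pmean only works inside a transformation that binds the axis name it reduces over, which is why the loss function further down switches to jax.vmap(model, axis_name="batch"). As a standalone illustration (a sketch, not part of Equinox; batch_stats is a hypothetical helper), pmean and the pvar helper under a named vmap reproduce ordinary batch statistics, broadcast back to every example:

```python
import jax
import jax.lax as lax
import jax.numpy as jnp

def pvar(x, axis_name):
    # Population variance across the named (vmapped) axis.
    mean = lax.pmean(x, axis_name)
    return lax.pmean((x - mean) ** 2, axis_name)

def batch_stats(x):
    # Inside vmap, `x` is a single example; pmean/pvar reduce across the batch.
    return lax.pmean(x, "batch"), pvar(x, "batch")

xs = jnp.array([[1.0, 2.0], [3.0, 6.0], [5.0, 10.0]])  # batch of 3 examples
means, vars_ = jax.vmap(batch_stats, axis_name="batch")(xs)
```

Each row of means equals jnp.mean(xs, axis=0), and each row of vars_ equals jnp.var(xs, axis=0).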

The really weird thing here is the use of object.__setattr__. This is essentially equivalent to self.running_mean = running_mean; self.running_var = running_var, but hackily avoids the error message you get about trying to mutate something. (eqx.Modules are supposed to be immutable after __init__.)
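Since eqx.Module behaves much like a frozen dataclass after __init__, the trick can be illustrated with the standard library alone (an analogy, not Equinox's exact mechanism):

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class Frozen:
    value: int

f = Frozen(1)
try:
    f.value = 2  # normal attribute assignment is blocked on a frozen instance
except dataclasses.FrozenInstanceError:
    blocked = True
else:
    blocked = False

# object.__setattr__ skips the frozen class's __setattr__ and writes the field anyway.
object.__setattr__(f, "value", 2)
```

Inside jax.jit, a module argument is rebuilt from its flattened leaves, so an in-place change like this only touches that traced copy. That is why the updated model has to be returned explicitly.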

Assuming you've wrapped everything up in a JIT wrapper, this means you will need to pass out anything you've mutated. That is, your loss computation needs to change from something looking like

@jax.jit
@jax.grad
def loss(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jnp.mean((y - pred_y)**2)

to something looking like

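import functools

@jax.jit
@functools.partial(jax.grad, has_aux=True)  # the model is returned as an auxiliary, non-differentiated output
def loss(model, x, y):
    pred_y = jax.vmap(model, axis_name="batch")(x)
    return jnp.mean((y - pred_y)**2), model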

and you'll need to use the updated version of model -- which has the updated batch statistics on it.
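For comparison, here is a minimal pure-JAX sketch of the same rule with the statistics kept in a separate state dict instead of on the model (no Equinox; batchnorm and grad_step are hypothetical names). The state goes into the jitted function and the new state comes back out, and the caller must rebind it. Unlike the Equinox sketch above, this uses the common training-mode convention of normalising with the current batch statistics.

```python
import jax
import jax.numpy as jnp

def batchnorm(params, state, x, momentum=0.99, eps=1e-5):
    # x has shape (batch, features); statistics are taken over the batch axis.
    mean = jnp.mean(x, axis=0)
    var = jnp.var(x, axis=0)
    new_state = {
        "mean": momentum * state["mean"] + (1 - momentum) * mean,
        "var": momentum * state["var"] + (1 - momentum) * var,
    }
    y = (x - mean) / jnp.sqrt(var + eps)
    return params["scale"] * y + params["bias"], new_state

@jax.jit
def grad_step(params, state, x, y):
    def loss(params):
        pred, new_state = batchnorm(params, state, x)
        return jnp.mean((y - pred) ** 2), new_state
    # has_aux=True: differentiate the first output, pass the second through.
    grads, new_state = jax.grad(loss, has_aux=True)(params)
    return grads, new_state

params = {"scale": jnp.ones(2), "bias": jnp.zeros(2)}
state = {"mean": jnp.zeros(2), "var": jnp.ones(2)}
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
grads, state = grad_step(params, state, x, x)  # rebind: the old state is now stale
```

Forgetting the rebinding on the last line silently keeps the stale statistics, which is the footgun described here.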

Forgetting to use the updated version of model is a potential footgun, which is why I'm a bit concerned about having it in the library. (Flax gets around this by just having footguns all over the place, so having another new one isn't a surprise...)

Other minor things to note.

patrick-kidger commented 2 years ago

Oh, some other remarks:

  1. If you ever have nested JITs or anything like that, then the rule of "pass your model out of JIT" starts becoming easier to miss and harder to follow.
  2. It is possible to imagine other implementations that try and work around this passing-model-out for you, but they start introducing deeper kinds of magic.

BatchNorm is now much less popular than it used to be -- it largely seems to have been supplanted by LayerNorm etc. -- and moreover it's the only case I'm aware of in which we want to do some weird mutable-buffer stuff like this. So honestly I've been avoiding adding BatchNorm to the library because it takes a little finesse to use without surprises, and one of the really nice things about Equinox so far (compared to Flax etc.) has been its lack of footguns and surprises.
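Part of what makes LayerNorm easier is that it normalises each example over its own features, so it needs neither a cross-batch axis name nor running statistics. A minimal sketch (hypothetical layer_norm helper, with no learned scale or shift):

```python
import jax
import jax.numpy as jnp

def layer_norm(x, eps=1e-5):
    # Normalise a single example over its own features: no batch axis, no state.
    mean = jnp.mean(x)
    var = jnp.var(x)
    return (x - mean) / jnp.sqrt(var + eps)

xs = jnp.array([[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]])
out = jax.vmap(layer_norm)(xs)  # plain vmap: no axis_name required
```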

marcelroed commented 2 years ago

Thank you for your detailed response, @patrick-kidger!

I think you're right that BatchNorm doesn't mesh well with the functional way of doing things, and that LayerNorm does the trick most of the time. It would still be nice to be able to recreate existing models with Equinox, however, but as you say it's very prone to error.

Perhaps it's possible to add a decorator marking mutability; in my changes I've named it eqx.mutates:

def pvar(x, axis_name):
    mean = lax.pmean(x, axis_name)
    var = lax.pmean((x - mean)**2, axis_name)
    return var

class BatchNorm(eqx.Module):
    running_mean: Optional[jnp.ndarray]
    running_var: Optional[jnp.ndarray]
    momentum: float
    axis_name: str
    eps: float
    update_stats: bool

    def __init__(self, momentum=0.99, axis_name="batch", eps=1e-5, update_stats=True, **kwargs):
        super().__init__(**kwargs)
        self.running_mean = None
        self.running_var = None
        self.momentum = momentum
        self.axis_name = axis_name
        self.eps = eps
        self.update_stats = update_stats

    # Marks that this function returns a mutated model.
    # Will return a PyTree with only the updated variables alongside the result.
    # (eqx.mutates and self.mutate are proposed here; they are not existing Equinox APIs.)
    @eqx.mutates
    def __call__(self, x, update_stats=None):
        if update_stats is None:
            update_stats = self.update_stats
        mean = lax.pmean(x, axis_name=self.axis_name)
        var = pvar(x, axis_name=self.axis_name)
        if self.running_mean is None:
            running_mean = mean
            running_var = var
        else:
            running_mean = self.running_mean
            running_var = self.running_var
        running_mean = (1 - self.momentum) * mean + self.momentum * lax.stop_gradient(running_mean)
        running_var = (1 - self.momentum) * var + self.momentum * lax.stop_gradient(running_var)
        x = (x - running_mean) / jnp.sqrt(running_var + self.eps)
        if update_stats:
            self.mutate('running_mean', running_mean)  # Raises an error if the method isn't decorated
            self.mutate('running_var', running_var)
        return x

I'm thinking that calling a mutates method would return the result along with the updates when it isn't called from another mutates context; when it is, it would instead register its mutations with the parent instance. Finally, the returned updates can be merged into the model PyTree. I know this behavior is possible with a reasonable amount of magic in normal Python, but I'm not entirely sure what it might look like in JAX with JIT.

There are some other applications where having some mutable state is useful. For example, I've seen counters used for modules that can be called a variable number of times, as in an adaptive step size ODE solver, and for modules that need to memorize what actions have been taken in reinforcement learning. If I'm not mistaken, pretty much anything that needs to change other than through gradients needs this kind of mutability, so I think implementing some solution would make Equinox a lot more versatile.

patrick-kidger commented 2 years ago

So I think the mutates decorator you're suggesting is essentially equivalent to my first proposal. I don't think there's any benefit to using a mutates decorator that delays the mutation until after the function has been traced. (And either way you still need to pass the mutated model out of the JIT'd region.)

I agree that changing with anything except gradients needs this kind of mutability. I'd love to have a neat solution to this in Equinox but I don't think there's a generic approach that's compatible with JAX's functional paradigm. The best ideas I've got in this space involve either creating custom interpreters or using XLA's infeed and outfeed (i.e. jax.experimental.host_callback.call), and those are both relatively deep magic.

At least for adaptive step size ODE solvers, any decent solver should report the number of steps it took to you as a statistic. I think that kind of thing is true of most use-cases: there's usually some other way around it.
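Here is a toy illustration of that pattern, with a hypothetical halve_until_small standing in for a solver's adaptive loop. The step counter travels in the lax.while_loop carry and comes back as an output, so nothing has to be mutated:

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def halve_until_small(x):
    # Carry the step count alongside the value and return it as a statistic.
    def cond(carry):
        value, _ = carry
        return value >= 1.0

    def body(carry):
        value, num_steps = carry
        return value / 2, num_steps + 1

    return lax.while_loop(cond, body, (x, jnp.int32(0)))

value, num_steps = halve_until_small(jnp.array(10.0))
```

For an input of 10.0 the loop runs 10 → 5 → 2.5 → 1.25 → 0.625, so num_steps is 4.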

patrick-kidger commented 2 years ago

Here's a quick first-pass at the XLA infeed/outfeed approach.

import equinox as eqx
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.experimental.host_callback as hcb

cache = {}
shapes = {}
index = 0

def genindex():
    global index
    index += 1
    return index

def _get_state(i):
    return cache[i.item()]

def _set_state(i__v):
    i, v = i__v
    cache[i.item()] = v

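def get_state(i):
    try:
        shape = shapes[i]
    except KeyError as e:
        # Fail at trace time rather than inside the host callback.
        raise RuntimeError("Cannot get state before it has been set") from e
    return hcb.call(_get_state, i, result_shape=shape)

def set_state(i, v):
    try:
        shape = shapes[i]
    except KeyError:
        # First write: record the shape/dtype so that later reads know what to expect.
        shapes[i] = jax.ShapeDtypeStruct(shape=v.shape, dtype=v.dtype)
    else:
        if shape.shape != v.shape or shape.dtype != v.dtype:
            raise ValueError("Cannot change the shape or dtype of existing state")
    v = lax.stop_gradient(v)
    hcb.call(_set_state, (i, v))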

class BatchNorm(eqx.Module):
    mean_index: int = eqx.static_field()
    momentum: float

    def __init__(self):
        self.mean_index = genindex()
        set_state(self.mean_index, jnp.zeros(()))
        self.momentum = 0.99

    def __call__(self, x):
        running_mean = (1 - self.momentum) * jnp.mean(x) + self.momentum * get_state(self.mean_index)
        set_state(self.mean_index, running_mean)
        return x - running_mean

model = BatchNorm()

@jax.jit
def forward(x):
    return model(x)

print(forward(jnp.array(1.)))
print(forward(jnp.array(1.)))
print(forward(jnp.array(1.)))  # Different results each time!

As you can see, this is essentially a toy implementation written for readability: it still only uses a default value for the running mean (zero) rather than the first evaluation of the model; it only subtracts off the mean and doesn't normalise by variance; and it doesn't have a post-normalisation affine transformation. But those are all just non-technical implementation details, of course.
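For reference, the full computation that the toy omits can be written as below. This sketch matches the first BatchNorm sketch in this thread, which normalises with the running statistics. The learned affine parameters γ and β are an assumption here, since none of the sketches include them. In the formula, x̄ and s² are the batch mean and (biased) variance from pmean/pvar, m is the momentum, ε is eps, and μ₀, σ²₀ come from the first batch.

```latex
\begin{aligned}
\mu_t &= (1 - m)\,\bar{x} + m\,\mu_{t-1}, \qquad
\sigma^2_t = (1 - m)\,s^2 + m\,\sigma^2_{t-1}, \\
y &= \gamma\,\frac{x - \mu_t}{\sqrt{\sigma^2_t + \epsilon}} + \beta
\end{aligned}
```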

In practice this still has a few technical details that need sorting out:

  1. support for vmap;
  2. removing old elements from cache and shapes.

But fundamentally I think this approach works.
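Newer JAX releases have deprecated jax.experimental.host_callback. Here is a hedged sketch of the same host-side-state idea using jax.experimental.io_callback instead. Like the toy above, it keeps only a scalar running mean in a single hard-coded slot 0, and the helper names are hypothetical. It does not address the vmap or cache-cleanup points listed above.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

_store = {}  # host-side state, analogous to `cache` above
MOMENTUM = 0.99

def _get(i):
    return _store[int(i)]

def _set(i, v):
    _store[int(i)] = np.asarray(v)

def get_state(i):
    # Scalar float32 state only, to keep the sketch small.
    result = jax.ShapeDtypeStruct((), jnp.float32)
    return io_callback(_get, result, jnp.asarray(i, dtype=jnp.int32), ordered=True)

def set_state(i, v):
    # No result (None); ordered=True keeps the read/write order across calls.
    io_callback(_set, None, jnp.asarray(i, dtype=jnp.int32), jax.lax.stop_gradient(v), ordered=True)

_store[0] = np.zeros((), dtype=np.float32)  # initial running mean

@jax.jit
def forward(x):
    running_mean = (1 - MOMENTUM) * jnp.mean(x) + MOMENTUM * get_state(0)
    set_state(0, running_mean)
    return x - running_mean

outputs = [float(forward(jnp.array(1.0))) for _ in range(3)]  # different results each time
```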

CC @FedericoV as I recall you needed something similar -- you said you wanted to cache hints for optimisation problems. I recall mentioning the above as a solution to you; Idk if you ever implemented anything like it.

patrick-kidger commented 2 years ago

Well, I got successfully nerd-sniped into spending my weekend implementing this. (Mostly the new "stateful" technology that makes this possible.)

equinox.experimental.BatchNorm now exists. Happy hacking.

marcelroed commented 2 years ago

Wow, amazing work, Patrick! This is brilliant! I'm definitely using Equinox for my project. I'm reading your implementation to understand a bit better how exactly hcb.call allows for this mutability.

marcelroed commented 2 years ago

Sorry to continue this discussion on a closed issue, but I can't get the tests to succeed on my machine (M1 Mac). It seems jax.experimental.host_callback.call causes a crash instead of raising an error when things go wrong.

Running this test https://github.com/patrick-kidger/equinox/blob/fcf6dadab0a3b99e9bcdc37e67073e097115cbd8/tests/test_stateful.py#L52-L56 results in a crash on this line https://github.com/patrick-kidger/equinox/blob/fcf6dadab0a3b99e9bcdc37e67073e097115cbd8/equinox/experimental/stateful.py#L139

When running under pytest the error output is blank apart from a pytest-specific stack trace, but running the test file directly shows that JAX isn't catching the pybind11 exception raised from the callback.

/opt/homebrew/Caskroom/miniforge/base/envs/equinox/bin/python /Users/marcel/git/equinox/tests/test_stateful.py
/opt/homebrew/Caskroom/miniforge/base/envs/equinox/lib/python3.10/site-packages/jax/_src/lib/__init__.py:33: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
ERROR:absl:Outside call <function _call.<locals>.<lambda> at 0x134476710> threw exception 'Cannot get state before it has been set'.
libc++abi: terminating with uncaught exception of type pybind11::error_already_set: KeyError: 'Cannot get state before it has been set'

At:
  /Users/marcel/git/equinox/equinox/experimental/stateful.py(95): _get_state_hcb
  /opt/homebrew/Caskroom/miniforge/base/envs/equinox/lib/python3.10/site-packages/jax/experimental/host_callback.py(719): <lambda>
  /opt/homebrew/Caskroom/miniforge/base/envs/equinox/lib/python3.10/site-packages/jax/experimental/host_callback.py(1203): _outside_call_run_callback
  /opt/homebrew/Caskroom/miniforge/base/envs/equinox/lib/python3.10/site-packages/jax/experimental/host_callback.py(1067): wrapped_callback
  /opt/homebrew/Caskroom/miniforge/base/envs/equinox/lib/python3.10/site-packages/jax/_src/dispatch.py(444): _execute_compiled
  /opt/homebrew/Caskroom/miniforge/base/envs/equinox/lib/python3.10/site-packages/jax/_src/dispatch.py(94): apply_primitive
  /opt/homebrew/Caskroom/miniforge/base/envs/equinox/lib/python3.10/site-packages/jax/experimental/host_callback.py(960): _outside_call_impl
  /opt/homebrew/Caskroom/miniforge/base/envs/equinox/lib/python3.10/site-packages/jax/core.py(611): process_primitive
  /opt/homebrew/Caskroom/miniforge/base/envs/equinox/lib/python3.10/site-packages/jax/core.py(289): bind_with_trace
  /opt/homebrew/Caskroom/miniforge/base/envs/equinox/lib/python3.10/site-packages/jax/core.py(286): bind
  /opt/homebrew/Caskroom/miniforge/base/envs/equinox/lib/python3.10/site-packages/jax/experimental/host_callback.py(739): _call
  /opt/homebrew/Caskroom/miniforge/base/envs/equinox/lib/python3.10/site-packages/jax/experimental/host_callback.py(689): call
  /Users/marcel/git/equinox/equinox/experimental/stateful.py(110): _get_state
  /Users/marcel/git/equinox/equinox/experimental/stateful.py(140): get_state
  /Users/marcel/git/equinox/tests/test_stateful.py(56): test_no_set
  /Users/marcel/git/equinox/tests/test_stateful.py(164): <module>

Even when wrapping in try/except for KeyError this crashes the program.

I assume this is an issue with Jax/jaxlib on M1?

patrick-kidger commented 2 years ago

Yeah, this is a known issue with JAX -- namely, that host_callback.call handles errors differently depending on OS, device, or phase of the moon.

See also https://github.com/google/jax/issues/9457

I don't think there's much that can be done about this one from the point of view of Equinox I'm afraid.