marcelroed closed this issue 2 years ago.
Right. So this is something I've been avoiding implementing. It is possible using Equinox but I'm honestly not sure it belongs in the library itself, as it requires changing how you use things slightly.
First of all, an (untested) implementation looks something like this.
```python
from typing import Optional

import equinox as eqx
import jax.numpy as jnp
from jax import lax


def pvar(x, axis_name):
    # Variance across the named (vmapped/pmapped) axis.
    mean = lax.pmean(x, axis_name)
    var = lax.pmean((x - mean) ** 2, axis_name)
    return var


class BatchNorm(eqx.Module):
    running_mean: Optional[jnp.ndarray]
    running_var: Optional[jnp.ndarray]
    momentum: float
    axis_name: str
    eps: float
    update_stats: bool

    def __init__(self, momentum=0.99, axis_name="batch", eps=1e-5, update_stats=True, **kwargs):
        super().__init__(**kwargs)
        self.running_mean = None
        self.running_var = None
        self.momentum = momentum
        self.axis_name = axis_name
        self.eps = eps
        self.update_stats = update_stats

    def __call__(self, x, update_stats=None):
        if update_stats is None:
            update_stats = self.update_stats
        # Batch statistics, averaged over every example sharing `axis_name`.
        mean = lax.pmean(x, axis_name=self.axis_name)
        var = pvar(x, axis_name=self.axis_name)
        if self.running_mean is None:
            # First call: initialise the running statistics from this batch.
            running_mean = mean
            running_var = var
        else:
            running_mean = self.running_mean
            running_var = self.running_var
        # Exponential moving average; the old statistics never receive gradients.
        running_mean = (1 - self.momentum) * mean + self.momentum * lax.stop_gradient(running_mean)
        running_var = (1 - self.momentum) * var + self.momentum * lax.stop_gradient(running_var)
        x = (x - running_mean) / jnp.sqrt(running_var + self.eps)
        if update_stats:
            # Bypass eqx.Module's immutability (see below).
            object.__setattr__(self, "running_mean", running_mean)
            object.__setattr__(self, "running_var", running_var)
        return x
```
The really weird thing here is the use of `object.__setattr__`. This is essentially equivalent to `self.running_mean = running_mean; self.running_var = running_var`, but hackily avoids the error message you get about trying to mutate something. (`eqx.Module`s are supposed to be immutable after `__init__`.)
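Equinox Modules are frozen dataclasses, so the trick can be seen with the standard library alone. In this minimal sketch, `FrozenStats` is a hypothetical stand-in for an `eqx.Module`, not Equinox's own class: ordinary assignment raises `FrozenInstanceError`, while `object.__setattr__` skips the frozen `__setattr__` and mutates the instance in place.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class FrozenStats:
    # Hypothetical stand-in for an eqx.Module: immutable after __init__.
    running_mean: float


stats = FrozenStats(running_mean=0.0)

# Ordinary attribute assignment is rejected by the frozen dataclass.
try:
    stats.running_mean = 1.0
    normal_assignment_failed = False
except dataclasses.FrozenInstanceError:
    normal_assignment_failed = True

# object.__setattr__ bypasses the dataclass's __setattr__ override entirely,
# so the "immutable" instance is changed in place.
object.__setattr__(stats, "running_mean", 1.0)
print(normal_assignment_failed, stats.running_mean)
```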
Assuming you've wrapped everything up in a JIT'd function, this means you will need to pass out anything you've mutated. That is, your loss computation needs to change from something looking like
```python
@jax.jit
@jax.grad
def loss(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jnp.mean((y - pred_y)**2)
```
to something looking like
```python
import functools


@jax.jit
@functools.partial(jax.grad, has_aux=True)  # returns (grads, updated model)
def loss(model, x, y):
    pred_y = jax.vmap(model, axis_name="batch")(x)
    return jnp.mean((y - pred_y)**2), model
```
and you'll need to use the updated version of `model`, which has the updated batch statistics on it.
Forgetting to use the updated version of `model` is a potential footgun, which is why I'm a bit concerned about having it in the library. (Flax gets around this by just having footguns all over the place, so having another new one isn't a surprise...)
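If you'd rather not mutate the model at all, one alternative is to keep the batch statistics in a separate pytree and return the new values explicitly as auxiliary output. The sketch below is plain JAX rather than anything Equinox provides; `batch_norm`, `loss_fn`, `train_step`, the toy scalar weight `w`, and the `(mean, var)` state tuple are all hypothetical names for illustration. It keeps the same footgun (the caller must hold on to `new_state`), but at least the state now appears in the function signatures.

```python
import jax
import jax.numpy as jnp
from jax import lax

MOMENTUM = 0.99
EPS = 1e-5


def batch_norm(x, state):
    # x is a single example; statistics are averaged over the vmapped "batch" axis.
    running_mean, running_var = state
    mean = lax.pmean(x, axis_name="batch")
    var = lax.pmean((x - mean) ** 2, axis_name="batch")
    # Blend batch statistics into the running ones; the old state gets no gradient.
    new_mean = (1 - MOMENTUM) * mean + MOMENTUM * lax.stop_gradient(running_mean)
    new_var = (1 - MOMENTUM) * var + MOMENTUM * lax.stop_gradient(running_var)
    y = (x - new_mean) / jnp.sqrt(new_var + EPS)
    return y, (new_mean, new_var)


def loss_fn(w, state, x, y):
    def per_example(xi):
        return batch_norm(w * xi, state)

    pred, (new_mean, new_var) = jax.vmap(per_example, axis_name="batch")(x)
    loss = jnp.mean((y - pred) ** 2)
    # The new statistics are identical for every example, so keep the first copy.
    return loss, (new_mean[0], new_var[0])


@jax.jit
def train_step(w, state, x, y):
    (loss, new_state), grad = jax.value_and_grad(loss_fn, has_aux=True)(w, state, x, y)
    # The caller must use new_state from now on, mirroring "use the updated model".
    return w - 0.1 * grad, new_state, loss
```

For example, `train_step(jnp.array(2.0), (jnp.zeros(3), jnp.ones(3)), x, y)` with `x` of shape `(batch, 3)` returns the updated weight, the new `(mean, var)` pair, and the loss.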
Other minor things to note:
- The `axis_name` argument to `vmap`.
- The use of `lax.stop_gradient`. We never want to update these arrays via gradient descent.
- The `update_stats` flag, used to disable updates during inference.

(None of these is particularly weird, at least.)

Oh, some other remarks:
`BatchNorm` is now much less popular than it used to be (it largely seems to have been supplanted by `LayerNorm` etc.), and moreover it's the only case I'm aware of in which we want to do some weird mutable-buffer stuff like this. So honestly I've been avoiding adding `BatchNorm` to the library because it takes a little finesse to use without surprises, and one of the really nice things about Equinox so far (compared to Flax etc.) has been its lack of footguns and surprises.
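The `axis_name`/`vmap` and `lax.stop_gradient` points from the list above can each be checked in isolation. In this small sketch, `centre` and `blend` are hypothetical helpers, not part of the `BatchNorm` code: inside `jax.vmap(..., axis_name="batch")`, `lax.pmean` averages across all mapped examples, and differentiating through `lax.stop_gradient` gives zero with respect to the old running value.

```python
import jax
import jax.numpy as jnp
from jax import lax


def centre(x):
    # Inside vmap(..., axis_name="batch"), pmean averages over all mapped examples.
    return x - lax.pmean(x, axis_name="batch")


def blend(running, batch_stat, momentum=0.99):
    # Same update rule as the BatchNorm sketch: the old running value is frozen.
    return (1 - momentum) * batch_stat + momentum * lax.stop_gradient(running)


xs = jnp.array([1.0, 2.0, 3.0, 6.0])
centred = jax.vmap(centre, axis_name="batch")(xs)  # xs minus the batch mean (3.0)

# Gradient with respect to the running statistic is zero thanks to stop_gradient.
d_running = jax.grad(blend, argnums=0)(1.0, 2.0)
```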
Thank you for your detailed response, @patrick-kidger!
I think you're right that `BatchNorm` doesn't mesh well with the functional way of doing things, and that `LayerNorm` does the trick most of the time. It would still be nice to be able to recreate existing models with Equinox, but as you say it's very prone to error.
Perhaps we could add a decorator that marks a method as mutating; in my changes it's named `eqx.mutates`:
```python
from typing import Optional

import equinox as eqx
import jax.numpy as jnp
from jax import lax


def pvar(x, axis_name):
    mean = lax.pmean(x, axis_name)
    var = lax.pmean((x - mean) ** 2, axis_name)
    return var


class BatchNorm(eqx.Module):
    running_mean: Optional[jnp.ndarray]
    running_var: Optional[jnp.ndarray]
    momentum: float
    axis_name: str
    eps: float
    update_stats: bool

    def __init__(self, momentum=0.99, axis_name="batch", eps=1e-5, update_stats=True, **kwargs):
        super().__init__(**kwargs)
        self.running_mean = None
        self.running_var = None
        self.momentum = momentum
        self.axis_name = axis_name
        self.eps = eps
        self.update_stats = update_stats

    # Proposed API (not an existing Equinox function).
    # Marks that this function returns a mutated model.
    # Will return a PyTree with only the updated variables alongside the result.
    @eqx.mutates
    def __call__(self, x, update_stats=None):
        if update_stats is None:
            update_stats = self.update_stats
        mean = lax.pmean(x, axis_name=self.axis_name)
        var = pvar(x, axis_name=self.axis_name)
        if self.running_mean is None:
            running_mean = mean
            running_var = var
        else:
            running_mean = self.running_mean
            running_var = self.running_var
        running_mean = (1 - self.momentum) * mean + self.momentum * lax.stop_gradient(running_mean)
        running_var = (1 - self.momentum) * var + self.momentum * lax.stop_gradient(running_var)
        x = (x - running_mean) / jnp.sqrt(running_var + self.eps)
        if update_stats:
            self.mutate('running_mean', running_mean)  # Raises an error if the method isn't decorated
            self.mutate('running_var', running_var)
        return x
```
I'm thinking that calling a `mutates` method would return the results along with the updates if it isn't called from within another `mutates` method; if it is, it'll instead register its mutations with the parent instance. Finally, the returned updates can be merged into the model PyTree. I know this behavior is possible with a reasonable amount of magic in normal Python, but I'm not entirely sure what it might look like in JAX with JIT.
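Here is one way the proposed behaviour could look in plain Python, outside of JAX entirely (so it says nothing about how this would interact with `jit`). `mutates`, `Mutable.mutate`, and `merge` are hypothetical names mirroring the proposal, not an Equinox API; frozen dataclasses stand in for `eqx.Module`s, and a `contextvars.ContextVar` tracks whether we are already inside a `mutates` call.

```python
import contextvars
import dataclasses
import functools

# Pending updates for the outermost @mutates call in progress (None = not inside one).
_pending = contextvars.ContextVar("_pending", default=None)


def mutates(method):
    """Hypothetical decorator: the outermost call returns (result, updates)."""

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if _pending.get() is not None:
            # Nested call: mutations are recorded into the enclosing call's list.
            return method(self, *args, **kwargs)
        token = _pending.set([])
        try:
            result = method(self, *args, **kwargs)
            updates = _pending.get()
        finally:
            _pending.reset(token)
        return result, updates

    return wrapper


class Mutable:
    def mutate(self, name, value):
        # Record, rather than perform, the mutation.
        updates = _pending.get()
        if updates is None:
            raise RuntimeError("mutate() must be called from a @mutates method")
        updates.append((self, name, value))


def merge(obj, updates):
    """Rebuild a tree of frozen dataclasses with the recorded updates applied."""
    changes = {name: value for target, name, value in updates if target is obj}
    for field in dataclasses.fields(obj):
        child = getattr(obj, field.name)
        if field.name not in changes and dataclasses.is_dataclass(child):
            changes[field.name] = merge(child, updates)
    return dataclasses.replace(obj, **changes)


@dataclasses.dataclass(frozen=True)
class Counter(Mutable):
    count: int = 0

    @mutates
    def __call__(self, x):
        self.mutate("count", self.count + 1)
        return 2 * x


@dataclasses.dataclass(frozen=True)
class Model(Mutable):
    counter: Counter

    @mutates
    def __call__(self, x):
        return self.counter(x) + 1  # nested call: a plain result comes back here


model = Model(Counter())
out, updates = model(3)
model = merge(model, updates)
print(out, model.counter.count)
```

The inner `Counter` call records its update into the outer call's list instead of returning a tuple, and `merge` rebuilds the tree with `dataclasses.replace` rather than mutating anything. It only recurses into dataclass-valued fields, which is enough for this sketch.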
There are some other applications where having some mutable state is useful. For example, I've seen counters used for modules that can be called a variable number of times, such as inside an adaptive step size ODE solver, and for modules that need to memorize what actions have been taken in reinforcement learning. If I'm not mistaken, pretty much anything that needs to change by any means other than gradients needs this kind of mutability, so I think implementing some solution would make Equinox a lot more versatile.
So I think the `mutates` decorator you're suggesting is essentially equivalent to my first proposal. I don't think there's any benefit to using a `mutates` decorator that delays the mutation until after the function has been traced. (And either way you still need to pass the mutated model out of the JIT'd region.)
I agree that anything that changes by means other than gradients needs this kind of mutability. I'd love to have a neat solution to this in Equinox, but I don't think there's a generic approach that's compatible with JAX's functional paradigm. The best ideas I've got in this space involve either creating custom interpreters or using XLA's infeed and outfeed (i.e. `jax.experimental.host_callback.call`), and those are both relatively deep magic.
At least for adaptive step size ODE solvers, any decent solver should report the number of steps it took as a statistic. I think that kind of thing is true of most use cases: there's usually some other way around it.
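As an illustration of reporting a statistic instead of mutating a counter, the step count can simply be threaded through the loop carry and returned. In this sketch a toy fixed-step Euler solver for dy/dt = -y stands in for a real adaptive solver (an assumption made to keep it short); `euler_solve` is a hypothetical name.

```python
import jax.numpy as jnp
from jax import lax


def euler_solve(y0, t1, dt):
    """Toy Euler solver for dy/dt = -y on [0, t1]; returns (y(t1), number of steps)."""

    def cond(carry):
        t, _, _ = carry
        return t < t1

    def body(carry):
        t, y, num_steps = carry
        h = jnp.minimum(dt, t1 - t)  # shorten the last step so we land on t1
        # The counter is just another piece of loop state, not a mutable attribute.
        return t + h, y - h * y, num_steps + 1

    init = (
        jnp.asarray(0.0, dtype=jnp.float32),
        jnp.asarray(y0, dtype=jnp.float32),
        jnp.asarray(0, dtype=jnp.int32),
    )
    _, y1, num_steps = lax.while_loop(cond, body, init)
    return y1, num_steps


y1, num_steps = euler_solve(1.0, 1.0, 0.1)
```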
Here's a quick first-pass at the XLA infeed/outfeed approach.
```python
import equinox as eqx
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.experimental.host_callback as hcb

cache = {}   # host-side storage: index -> current value
shapes = {}  # index -> shape/dtype, needed to tell hcb.call what to expect
index = 0


def genindex():
    # Hand out a fresh integer key for each stateful layer.
    global index
    index += 1
    return index


def _get_state(i):
    # Runs on the host; i arrives as a NumPy array.
    return cache[i.item()]


def _set_state(i__v):
    i, v = i__v
    cache[i.item()] = v


def get_state(i):
    try:
        shape = shapes[i]
    except KeyError as e:
        raise RuntimeError("Cannot get state before it has been set") from e
    return hcb.call(_get_state, i, result_shape=shape)


def set_state(i, v):
    try:
        shape = shapes[i]
    except KeyError:
        # The first write fixes the shape/dtype for this index.
        shapes[i] = jax.ShapeDtypeStruct(shape=v.shape, dtype=v.dtype)
    else:
        if shape.shape != v.shape or shape.dtype != v.dtype:
            raise ValueError("State must keep the same shape and dtype")
    v = lax.stop_gradient(v)
    hcb.call(_set_state, (i, v))


class BatchNorm(eqx.Module):
    mean_index: int = eqx.static_field()
    momentum: float

    def __init__(self):
        self.mean_index = genindex()
        set_state(self.mean_index, jnp.zeros(()))
        self.momentum = 0.99

    def __call__(self, x):
        running_mean = (1 - self.momentum) * jnp.mean(x) + self.momentum * get_state(self.mean_index)
        set_state(self.mean_index, running_mean)
        return x - running_mean


model = BatchNorm()


@jax.jit
def forward(x):
    return model(x)


print(forward(jnp.array(1.)))
print(forward(jnp.array(1.)))
print(forward(jnp.array(1.)))  # Different results each time!
```
As you can see, this is essentially a toy implementation written for readability: it only uses a default value for the running mean (zero) rather than the first evaluation of the model; it only subtracts off the mean and doesn't normalise by the variance; and it doesn't have a post-normalisation affine transformation. But those are all just non-technical implementation details, of course.
In practice this still has a few technical details that need sorting out (concerning `cache` and `shapes`). But fundamentally I think this approach works.
CC @FedericoV, as I recall you needed something similar: you said you wanted to cache hints for optimisation problems. I recall mentioning the above as a solution to you; I don't know if you ever implemented anything like it.
Well, I got successfully nerd-sniped into spending my weekend implementing this. (Mostly the new "stateful" technology that makes this possible.)
`equinox.experimental.BatchNorm` now exists. Happy hacking.
Wow, amazing work, Patrick! This is brilliant! I'm definitely using Equinox for my project. I'm reading your implementation to understand a bit better how exactly `hcb.call` allows for this mutability.
Sorry to continue this discussion on a closed issue, but I can't get the tests to pass on my machine (M1 Mac). It seems `jax.experimental.host_callback.call` causes a crash instead of raising an error when things go wrong.
Running this test https://github.com/patrick-kidger/equinox/blob/fcf6dadab0a3b99e9bcdc37e67073e097115cbd8/tests/test_stateful.py#L52-L56 results in a crash on this line https://github.com/patrick-kidger/equinox/blob/fcf6dadab0a3b99e9bcdc37e67073e097115cbd8/equinox/experimental/stateful.py#L139
When running in pytest the error output is blank, with a pytest-specific stack trace, but running the test directly shows that JAX isn't catching the `pybind11` exception:
```
/opt/homebrew/Caskroom/miniforge/base/envs/equinox/bin/python /Users/marcel/git/equinox/tests/test_stateful.py
/opt/homebrew/Caskroom/miniforge/base/envs/equinox/lib/python3.10/site-packages/jax/_src/lib/__init__.py:33: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
ERROR:absl:Outside call <function _call.<locals>.<lambda> at 0x134476710> threw exception 'Cannot get state before it has been set'.
libc++abi: terminating with uncaught exception of type pybind11::error_already_set: KeyError: 'Cannot get state before it has been set'
At:
/Users/marcel/git/equinox/equinox/experimental/stateful.py(95): _get_state_hcb
/opt/homebrew/Caskroom/miniforge/base/envs/equinox/lib/python3.10/site-packages/jax/experimental/host_callback.py(719): <lambda>
/opt/homebrew/Caskroom/miniforge/base/envs/equinox/lib/python3.10/site-packages/jax/experimental/host_callback.py(1203): _outside_call_run_callback
/opt/homebrew/Caskroom/miniforge/base/envs/equinox/lib/python3.10/site-packages/jax/experimental/host_callback.py(1067): wrapped_callback
/opt/homebrew/Caskroom/miniforge/base/envs/equinox/lib/python3.10/site-packages/jax/_src/dispatch.py(444): _execute_compiled
/opt/homebrew/Caskroom/miniforge/base/envs/equinox/lib/python3.10/site-packages/jax/_src/dispatch.py(94): apply_primitive
/opt/homebrew/Caskroom/miniforge/base/envs/equinox/lib/python3.10/site-packages/jax/experimental/host_callback.py(960): _outside_call_impl
/opt/homebrew/Caskroom/miniforge/base/envs/equinox/lib/python3.10/site-packages/jax/core.py(611): process_primitive
/opt/homebrew/Caskroom/miniforge/base/envs/equinox/lib/python3.10/site-packages/jax/core.py(289): bind_with_trace
/opt/homebrew/Caskroom/miniforge/base/envs/equinox/lib/python3.10/site-packages/jax/core.py(286): bind
/opt/homebrew/Caskroom/miniforge/base/envs/equinox/lib/python3.10/site-packages/jax/experimental/host_callback.py(739): _call
/opt/homebrew/Caskroom/miniforge/base/envs/equinox/lib/python3.10/site-packages/jax/experimental/host_callback.py(689): call
/Users/marcel/git/equinox/equinox/experimental/stateful.py(110): _get_state
/Users/marcel/git/equinox/equinox/experimental/stateful.py(140): get_state
/Users/marcel/git/equinox/tests/test_stateful.py(56): test_no_set
/Users/marcel/git/equinox/tests/test_stateful.py(164): <module>
```
Even when wrapping it in `try`/`except` for `KeyError`, this crashes the program.
I assume this is an issue with Jax/jaxlib on M1?
Yeah, this is a known issue with JAX: namely, that `host_callback.call` handles errors differently depending on OS, device, or phase of the moon.
See also https://github.com/google/jax/issues/9457
I don't think there's much that can be done about this one from the point of view of Equinox I'm afraid.
In Flax, Batch Normalization is a bit finicky, since each call to `apply` requires marking `batch_stats` as mutable and updating the `batch_stats` afterward. How could this be implemented as a Module in Equinox? I'm happy to submit an implementation given some guidance.
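One way to get Flax-like behaviour without any mutation is to keep the batch statistics in their own pytree, have the layer return updated statistics in training mode, and leave them untouched at inference time (the role played by `update_stats` in the first sketch in this thread). This is a plain-JAX sketch under that assumption, not an existing Equinox layer; `BNParams`, `BatchStats`, `init`, and `apply` are hypothetical names. In an Equinox setting the statistics could be kept outside the Module and threaded through the training step.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class BNParams(NamedTuple):
    scale: jnp.ndarray  # learnable affine parameters
    shift: jnp.ndarray


class BatchStats(NamedTuple):
    mean: jnp.ndarray  # running statistics, never trained by gradient descent
    var: jnp.ndarray


def init(num_features):
    params = BNParams(jnp.ones(num_features), jnp.zeros(num_features))
    stats = BatchStats(jnp.zeros(num_features), jnp.ones(num_features))
    return params, stats


def apply(params, stats, x, *, train, momentum=0.99, eps=1e-5):
    """x has shape (batch, features). Returns (output, possibly-updated stats)."""
    if train:
        # Training: normalise with this batch's statistics and update the running ones.
        mean = jnp.mean(x, axis=0)
        var = jnp.var(x, axis=0)
        new_stats = BatchStats(
            momentum * stats.mean + (1 - momentum) * jax.lax.stop_gradient(mean),
            momentum * stats.var + (1 - momentum) * jax.lax.stop_gradient(var),
        )
    else:
        # Inference: use the stored running statistics and change nothing.
        mean, var = stats.mean, stats.var
        new_stats = stats
    y = (x - mean) / jnp.sqrt(var + eps)
    return params.scale * y + params.shift, new_stats
```

Unlike the first sketch in this thread, this normalises with the current batch's statistics during training (the usual convention) and only uses the running statistics at inference time.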