paganpasta / eqxvision

A Python package of computer vision models for the Equinox ecosystem.
https://eqxvision.readthedocs.io
MIT License

Fix batch norm #73

Closed: hlzl closed this 1 year ago

hlzl commented 1 year ago

This tries to revive PR #71 and resolve issue #70. These changes should complete the refactor to the new state mechanism. However, they currently cause the following error in the Sequential module:

  File "/home/user/eqxvision/eqx_train.py", line 296, in forward
    pred_ys, state = batch_model(images, state, key=keys)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/site-packages/eqxvision/models/classification/resnet.py", line 352, in __call__
    x, state = self.layer1(x, state, key=keys[1])
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/site-packages/equinox/nn/_sequential.py", line 68, in __call__
    x = layer(x, key=key)
        ^^^^^^^^^^^^^^^^^
TypeError: _ResNetBasicBlock.__call__() missing 1 required positional argument: 'state'
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Can be reproduced with:

import jax
import equinox as eqx
from eqxvision.models import resnet18

@eqx.filter_jit
def forward(model, state, images):
    keys = jax.random.split(jax.random.PRNGKey(0), images.shape[0])
    batch_model = jax.vmap(
        model, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
    )
    pred_ys, state = batch_model(images, state, key=keys)
    return pred_ys

net = resnet18()
state = eqx.nn.State(net)

images = jax.random.uniform(jax.random.PRNGKey(0), shape=(1, 3, 64, 64))
output = forward(net, state, images)

@patrick-kidger Any idea how to solve this? Should we get rid of the if-else statement in _sequential that checks isinstance(layer, StatefulLayer) and causes this error?
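To make the failure concrete, here is a minimal pure-Python mock of the dispatch the traceback points to. It is a sketch, not Equinox's source: `StatefulLayer`, `Sequential` and `BasicBlock` are hypothetical stand-ins, and the state is a plain dict. A block that needs `state` but is not a `StatefulLayer` gets called as `layer(x, key=key)` and fails with the same kind of `TypeError`.

```python
class StatefulLayer:
    """Stand-in marker base class for layers that take and return state."""


class Sequential:
    """Simplified mock of a container that dispatches on isinstance."""

    def __init__(self, layers):
        self.layers = layers

    def __call__(self, x, state=None, *, key=None):
        for layer in self.layers:
            if isinstance(layer, StatefulLayer):
                x, state = layer(x, state=state, key=key)
            else:
                # Anything not marked as stateful is called without state.
                x = layer(x, key=key)
        return x, state


class BasicBlock:
    """A block that needs state but does NOT subclass StatefulLayer."""

    def __call__(self, x, state, *, key=None):
        return x + 1, state


seq = Sequential([BasicBlock()])
try:
    seq(0, {})
    message = None
except TypeError as err:
    message = str(err)
print(message)
```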

hlzl commented 1 year ago

As a side note, not all models have been fixed yet. I don't use the following models, so I don't know a suitable fix; this might be of interest to someone later:

hlzl commented 1 year ago

Seems like we can fix the isinstance(layer, StatefulLayer) error in the Sequential module if our layer-block classes (such as _ResNetBasicBlock() and _DenseLayer()) inherit from StatefulLayer. This way, the Sequential module recognizes them as stateful layers and knows that it has to pass in the state.
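A minimal sketch of this fix, again with hypothetical pure-Python mocks rather than Equinox's real classes: once the block subclasses the stand-in `StatefulLayer`, the container passes the state in and takes the updated state back, and the block pipes it through its own stateful child (`BatchNormMock`, which only counts its calls).

```python
class StatefulLayer:
    """Stand-in marker base class for layers that take and return state."""


class Sequential:
    """Simplified mock: stateful children receive and return the state."""

    def __init__(self, layers):
        self.layers = layers

    def __call__(self, x, state=None, *, key=None):
        for layer in self.layers:
            if isinstance(layer, StatefulLayer):
                x, state = layer(x, state=state, key=key)
            else:
                x = layer(x, key=key)
        return x, state


class BatchNormMock(StatefulLayer):
    """Hypothetical stateful layer that counts how often it has been called."""

    def __call__(self, x, state, *, key=None):
        return x, dict(state, calls=state.get("calls", 0) + 1)


class BasicBlock(StatefulLayer):
    """Subclassing StatefulLayer makes the container pass the state in."""

    def __init__(self):
        self.bn = BatchNormMock()

    def __call__(self, x, state, *, key=None):
        # Thread the state through the stateful child, then return it.
        x, state = self.bn(x, state, key=key)
        return x + 1, state


x, state = Sequential([BasicBlock(), BasicBlock()])(0, {})
print(x, state)  # 2 {'calls': 2}
```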

This, however, doesn't seem to work straightforwardly for more deeply nested models such as the EfficientNet implementation.

Maybe it would be better to set some kind of flag when a sequential layer has any child that requires state? @paganpasta

patrick-kidger commented 1 year ago

Seems like we can fix the isinstance(layer, StatefulLayer) error in the Sequential module if we inherit from StatefulLayer in our layer-block classes

Yup! This is the expected fix.

This, however, doesn't seem to work straightforwardly for more nested models such as the EfficientNet implementation. Maybe it would be a better idea to set some kind of flag if a sequential layer has any child that requires a state?

If you have a child stateful layer, then the parent is necessarily stateful as well -- it should subclass StatefulLayer, accept a state argument, return a state argument, and pipe the state to its child layer in between.

hlzl commented 1 year ago

Maybe I'm missing something here, but if I have a nested Sequential layer, how should the top-level Sequential know that its child Sequential contains a stateful layer? The isinstance(layer, StatefulLayer) check is not recursive, so it will just see the child as being of type Sequential, right?

At least, that seems to be what happened when I tried to change the EfficientNet implementation by making _MBConv and _FusedMBConv inherit from StatefulLayer, which caused the following error:

  File "/home/user/eqxvision/eqx_train.py", line 12, in forward
    pred_ys, state = batch_model(images, state, key=keys)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/site-packages/eqxvision/models/classification/efficientnet.py", line 405, in __call__
    x, state = self.features(x, state, key=keys[0])
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/site-packages/equinox/nn/_sequential.py", line 68, in __call__
    x = layer(x, key=key)
        ^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/site-packages/equinox/nn/_sequential.py", line 66, in __call__
    x, state = layer(x, state=state, key=key)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/site-packages/equinox/nn/_batch_norm.py", line 155, in __call__
    first_time = state.get(self.first_time_index)
                 ^^^^^^^^^
AttributeError: 'object' object has no attribute 'get'
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

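The second traceback can be mimicked the same way. In this hypothetical mock, the outer container only sees a `Sequential`, which is not a `StatefulLayer`, so it calls it without state; the inner container then falls back to its default. The `'object' object has no attribute 'get'` message suggests that this default is a bare `object()` sentinel. That is an inference from the error, and `_SENTINEL` below is a stand-in name.

```python
_SENTINEL = object()  # hypothetical stand-in for the "no state passed" default


class StatefulLayer:
    """Stand-in marker base class for layers that take and return state."""


class Sequential:
    """Simplified mock; it does NOT subclass StatefulLayer itself."""

    def __init__(self, layers):
        self.layers = layers

    def __call__(self, x, state=_SENTINEL, *, key=None):
        for layer in self.layers:
            if isinstance(layer, StatefulLayer):
                x, state = layer(x, state=state, key=key)
            else:
                # Non-recursive check: a nested Sequential lands here and
                # is called without any state.
                x = layer(x, key=key)
        return x if state is _SENTINEL else (x, state)


class BatchNormMock(StatefulLayer):
    """Hypothetical stateful layer that reads from a dict-like state."""

    def __call__(self, x, state, *, key=None):
        calls = state.get("calls", 0)  # fails if `state` is the bare sentinel
        return x, dict(state, calls=calls + 1)


outer = Sequential([Sequential([BatchNormMock()])])
try:
    outer(0, {})
    message = None
except AttributeError as err:
    message = str(err)
print(message)  # 'object' object has no attribute 'get'
```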
patrick-kidger commented 1 year ago

Ah, I take your point -- you'd like a way to programmatically propagate statefulness, since the containing classes (whether they are a Sequential or a custom class) may or may not be stateful, depending on their choice of sublayers.

I've just written https://github.com/patrick-kidger/equinox/pull/505, which adds a StatefulLayer.is_stateful method that is called to check whether or not a layer is stateful. And indeed, Sequential now inherits from StatefulLayer, with the following implementation:

def is_stateful(self):
    return any(isinstance(x, StatefulLayer) and x.is_stateful() for x in self.layers)

Thus nested Sequentials should now work automatically. In addition, you can have your own classes inherit from StatefulLayer and implement is_stateful, so that, if you wish, your custom layers are also handled statefully when placed inside a Sequential.
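The recursive check can be sketched with the same kind of hypothetical pure-Python mocks. Only the `is_stateful` body is taken from the comment above; the rest (how `__call__` dispatches and what it returns) is an assumption for illustration, not the code from the Equinox PR. A nested `Sequential` is treated as stateful only if something inside it is.

```python
class StatefulLayer:
    """Stand-in base class; subclasses report whether they need state."""

    def is_stateful(self):
        return True


class Sequential(StatefulLayer):
    """Simplified mock: stateful if and only if any of its layers is."""

    def __init__(self, layers):
        self.layers = layers

    def is_stateful(self):
        return any(isinstance(x, StatefulLayer) and x.is_stateful() for x in self.layers)

    def __call__(self, x, state=None, *, key=None):
        for layer in self.layers:
            if isinstance(layer, StatefulLayer) and layer.is_stateful():
                x, state = layer(x, state=state, key=key)
            else:
                x = layer(x, key=key)
        # Only stateful containers hand the state back to their parent.
        return (x, state) if self.is_stateful() else x


class BatchNormMock(StatefulLayer):
    """Hypothetical stateful layer that counts how often it has been called."""

    def __call__(self, x, state, *, key=None):
        return x, dict(state, calls=state.get("calls", 0) + 1)


def add_one(x, *, key=None):
    return x + 1


model = Sequential([add_one, Sequential([add_one, BatchNormMock()]), Sequential([add_one])])
x, state = model(0, {})
print(x, state)  # 3 {'calls': 1}
```

Here the middle `Sequential` receives and returns the state because it contains `BatchNormMock`, while the last one is called as a plain layer.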

Does that seem like it would work for you? If you can install Equinox from that branch and check that it meets your needs, then I'll include it in the upcoming release.

hlzl commented 1 year ago

That's exactly what I meant, and your proposed solution sounds great!

I haven't been able to finish testing, as I ran into the Cannot assign methods in __init__ error introduced in v0.11.0 (thrown, e.g., here for the ResNet).

Since a lot of eqxvision models create layer blocks in __init__ via class methods (probably carried over from the equivalent PyTorch implementations), I was wondering what the best approach to refactoring this would be before I start.

We don't want to create the layer blocks anew every time we call them, so I thought about using functools.cached_property. This still seems like a lot of additional code compared to the previous implementations, so I thought you might have had a better idea in mind when adding that error that I haven't thought of yet.

patrick-kidger commented 1 year ago

Great, I'm glad it works. I've just merged that change into dev.

As for that error -- that's a bug in something I just wrote, whoops. I think https://github.com/patrick-kidger/equinox/pull/508 should fix.

hlzl commented 1 year ago

All models using BatchNorm should work now. Tested with equinox/dev commit d9b018a.

paganpasta commented 1 year ago

@hlzl Hi, thanks for the PR and sorry for the late response.

I think the tests are currently failing because the Equinox updates are not yet packaged into a new release.

I'll test it locally, update, and merge accordingly soon. I'm currently tied down with a few deadlines this week.

patrick-kidger commented 1 year ago

I'm planning on doing the next Equinox release in the next week, by the way.

hlzl commented 1 year ago

No worries, thank you @paganpasta.

The models generally run, but it seems like they do not reproduce the same results as their PyTorch equivalent.

For example, when trying to reproduce CIFAR10 results of a PyTorch ResNet18 with an equivalent Equinox implementation, the Equinox version matches the PyTorch one during the first (and second) epoch, but then somehow stops learning anything useful and eventually converges back to random guessing.

The gradients start behaving very strangely in the third epoch, but I can't tell what causes this and leads to the exploding loss, especially since this setup should be identical to the PyTorch one, where the problem does not arise.

Examples to reproduce, using the current dev branch of Equinox and torch==2.0:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from tqdm import tqdm

def accuracy(outputs, labels):
    _, predicted = torch.max(outputs.data, 1)
    correct = (predicted == labels).sum().item()
    total = labels.size(0)
    return correct / total

######################################################

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)

batch_size = 256
num_epochs = 100

transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

trainset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

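# Random init (`pretrained` is deprecated in favour of `weights`); 10 outputs to match CIFAR10 and the Equinox script
model = torchvision.models.resnet18(weights=None, num_classes=10)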
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

for epoch in range(num_epochs):
    running_loss = 0.0
    running_accuracy = 0.0
    for images, labels in tqdm(
        trainloader, leave=False, desc=f"Epoch {epoch+1}", total=len(trainloader)
    ):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        running_accuracy += accuracy(outputs, labels) * 100

    epoch_loss = running_loss / len(trainloader)
    epoch_accuracy = running_accuracy / len(trainloader)
    print(
        f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}"
    )

vs.

import equinox as eqx
import eqxvision

import jax
import jax.numpy as jnp
import optax

import torch
import torchvision
from tqdm import tqdm

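def xavier_init(weight: jax.Array, key: jax.Array) -> jax.Array:
    # Despite the name, this is He (Kaiming) normal init: std = sqrt(2 / fan_in).
    # fan_in counts all inputs per output unit: in_channels * kh * kw for a conv
    # weight of shape (out, in, kh, kw), and in_features for a Linear weight.
    fan_in = weight.size // weight.shape[0]
    return jax.random.normal(key, weight.shape) * jnp.sqrt(2.0 / fan_in)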

def init_weights(model, init_fn, key):
    is_weight = lambda x: isinstance(x, (eqx.nn.Linear, eqx.nn.Conv))
    get_weights = lambda m: [
        x.weight
        for x in jax.tree_util.tree_leaves(m, is_leaf=is_weight)
        if is_weight(x)
    ]
    weights = get_weights(model)
    new_weights = [
        init_fn(weight, subkey)
        for weight, subkey in zip(weights, jax.random.split(key, len(weights)))
    ]
    new_model = eqx.tree_at(get_weights, model, new_weights)
    return new_model

@eqx.filter_jit
def loss(model, state, x, label):
    keys = jax.random.split(jax.random.PRNGKey(5678), x.shape[0])
    batch_model = jax.vmap(
        model, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
    )
    pred_ys, state = batch_model(x, state, key=keys)
    return optax.softmax_cross_entropy_with_integer_labels(pred_ys, label).mean(), state

@eqx.filter_jit
def make_step(model, state, opt_state, x, label):
    (val, state), grads = eqx.filter_value_and_grad(loss, has_aux=True)(
        model, state, x, label
    )
    updates, opt_state = opt.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, state, opt_state, val

@eqx.filter_jit
def inference(model, state, x):
    keys = jax.random.split(jax.random.PRNGKey(5678), x.shape[0])
    inference_model = eqx.Partial(eqx.tree_inference(model, value=True), state=state)
    return jax.vmap(inference_model)(x, key=keys)

@eqx.filter_jit
def accuracy(outputs, labels):
    predicted = jnp.argmax(outputs, 1)
    correct = (predicted == labels).sum()
    total = labels.size
    return correct / total

######################################################

batch_size = 256
num_epochs = 100

transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

trainset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

model = eqxvision.models.resnet18(num_classes=10)
model = init_weights(model, xavier_init, key=jax.random.PRNGKey(5678))

opt = optax.sgd(learning_rate=0.01, momentum=0.9)

state = eqx.nn.State(model)
opt_state = opt.init(eqx.filter(model, eqx.is_inexact_array))
for epoch in range(0, num_epochs):
    running_loss, running_accuracy = 0.0, 0.0
    for images, labels in tqdm(
        trainloader, leave=False, desc=f"Epoch {epoch+1}", total=len(trainloader)
    ):
        model, state, opt_state, loss_val = make_step(
            model,
            state,
            opt_state,
            images.numpy(),
            labels.numpy(),
        )

        out = inference(model, state, images.numpy())

        running_loss += loss_val
        running_accuracy += accuracy(out[0], labels.numpy()) * 100

    epoch_loss = running_loss / len(trainloader)
    epoch_accuracy = running_accuracy / len(trainloader)
    print(
        f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}"
    )

Would be great if one of you could have a look here.
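When comparing the two scripts, initialisation is one place where they can differ before training even starts. torchvision's ResNet initialises its conv weights with `kaiming_normal_(..., mode="fan_out", nonlinearity="relu")`, whereas a He-normal init based on fan-in gives a different scale, and for a conv weight of shape `(out, in, kh, kw)` counting only the `in` channels changes the scale by a further factor of `sqrt(kh * kw)`. The sketch below (hypothetical helpers `fans` and `kaiming_std`, following PyTorch's fan convention) only computes these standard deviations for the 7x7 stem conv of a torchvision-style ResNet18; it does not claim that initialisation explains the divergence.

```python
import math


def fans(shape):
    """(fan_in, fan_out) for a weight of shape (out, in, *kernel), PyTorch-style."""
    receptive_field = math.prod(shape[2:])  # 1 for a Linear weight of shape (out, in)
    return shape[1] * receptive_field, shape[0] * receptive_field


def kaiming_std(shape, mode="fan_in"):
    """He/Kaiming-normal standard deviation for ReLU networks: sqrt(2 / fan)."""
    fan_in, fan_out = fans(shape)
    fan = fan_in if mode == "fan_in" else fan_out
    return math.sqrt(2.0 / fan)


# 7x7 stem conv of a torchvision-style ResNet18: 3 -> 64 channels.
stem = (64, 3, 7, 7)
print(f"{math.sqrt(2.0 / stem[1]):.3f}")  # 0.816  (input channels only)
print(f"{kaiming_std(stem, 'fan_in'):.3f}")  # 0.117  (fan_in = 3 * 49 = 147)
print(f"{kaiming_std(stem, 'fan_out'):.4f}")  # 0.0253 (fan_out = 64 * 49 = 3136)
```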

patrick-kidger commented 1 year ago

Hmm, two quick thoughts:

FWIW, PyTorch and Equinox use slightly different batch norm implementations (there are many variants). But once you've removed batch norm, you could try initialising both models with the same weights and feeding them the same training batches, and see whether you can get close to bit-for-bit reproducibility.
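To illustrate one way batch norm variants can differ, here is a sketch of two common running-statistics conventions. PyTorch's `momentum=0.1` is the weight on the new batch statistic; an exponential moving average with `momentum=0.99` puts that weight on the old running value instead. Treating Equinox's `eqx.nn.BatchNorm` as the latter with a 0.99 default is an assumption about the version in use (check its docstring), and the sketch ignores any special handling of the first call, such as the `first_time` flag visible in the EfficientNet traceback.

```python
def torch_style_update(running, batch_stat, momentum=0.1):
    """PyTorch convention: `momentum` is the weight on the NEW batch statistic."""
    return (1 - momentum) * running + momentum * batch_stat


def ema_style_update(running, batch_stat, momentum=0.99):
    """EMA convention (assumed here for Equinox): `momentum` weights the OLD value."""
    return momentum * running + (1 - momentum) * batch_stat


# Feed a constant batch statistic of 1.0 for 10 steps, starting from 0.0.
torch_running, ema_running = 0.0, 0.0
for _ in range(10):
    torch_running = torch_style_update(torch_running, 1.0)
    ema_running = ema_style_update(ema_running, 1.0)

print(f"{torch_running:.3f}")  # 0.651
print(f"{ema_running:.3f}")  # 0.096
```

Under these settings the EMA variant moves its running estimates with a ten times smaller step, so inference-mode outputs (which use the running statistics) can lag behind the training-mode ones for longer.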

patrick-kidger commented 1 year ago

Btw, heads-up: Equinox v0.11.0 is now released! That shouldn't be a blocker for this PR any more.

hlzl commented 1 year ago

@patrick-kidger Sorry for the late reply. Regarding your two thoughts:

BTW, my simple test for comparing the two implementations was to overfit on CIFAR10 during training. This generally works out of the box with most ResNet implementations, but somehow I was not able to achieve it with the Equinox implementation using batch norm. Do you have any idea why this could be (or whether it is due to the different BN variant)?