Closed hlzl closed 1 year ago
As a side note, not all models have been fixed. I don't use the following models and thus don't know a suitable fix. Might be interesting for someone later:
Seems like we can fix the isinstance(layer, StatefulLayer) error in the Sequential module if we inherit from StatefulLayer in our layer-block classes (such as _ResNetBasicBlock(), _DenseLayer()). This way the sequential module actually registers the stateful layers and the necessity to pass in the state.
This, however, doesn't seem to work straightforwardly for more deeply nested models such as the EfficientNet implementation.
Maybe it would be a better idea to set some kind of flag if a sequential layer has any child that requires a state? @paganpasta
Seems like we can fix the isinstance(layer, StatefulLayer) error in the Sequential module if we inherit from StatefulLayer in our layer-block classes
Yup! This is the expected fix.
This, however, doesn't seem to work straightforwardly for more nested models such as the EfficientNet implementation. Maybe it would be a better idea to set some kind of flag if a sequential layer has any child that requires a state?
If you have a child stateful layer, then the parent itself is necessarily stateful as well -- it should subclass StatefulLayer, accept a state argument, return a state -- and pipe that state to its child layer in between.
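A minimal sketch of that pattern, in plain Python. The classes here (StatefulLayer, CallCounter, Block) are hypothetical stand-ins written for illustration, not the Equinox classes, and the state is just a dict:

```python
class StatefulLayer:
    """Stand-in marker base class (NOT equinox.nn.StatefulLayer)."""


class CallCounter(StatefulLayer):
    """Toy stateful child: records in `state` how often it has been called."""

    def __call__(self, x, state):
        calls = state.get("calls", 0) + 1
        return x, {**state, "calls": calls}


class Block(StatefulLayer):
    """Parent with a stateful child, so the parent is stateful too."""

    def __init__(self):
        self.scale = 2.0
        self.child = CallCounter()

    def __call__(self, x, state):
        x = x * self.scale               # stateless work before the child
        x, state = self.child(x, state)  # pipe the state through the child
        return x + 1.0, state            # hand the updated state back


y, state = Block()(3.0, {})
print(y, state)
```

The point is only the signature: state goes in, is threaded through every stateful child, and comes back out alongside the output.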
Maybe I'm missing something here, but if I have a nested sequential layer, how should the top sequential layer know that inside its child sequential layer there is a stateful layer?
The if isinstance(layer, StatefulLayer) condition is not recursive and will just see the child layer as being of type Sequential, right?
At least that seems to be what happened when I tried to change the EfficientNet implementation by converting _MBConv and _FusedMBConv to inherit from StatefulLayer, causing the following error:
  File "/home/user/eqxvision/eqx_train.py", line 12, in forward
    pred_ys, state = batch_model(images, state, key=keys)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/site-packages/eqxvision/models/classification/efficientnet.py", line 405, in __call__
    x, state = self.features(x, state, key=keys[0])
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/site-packages/equinox/nn/_sequential.py", line 68, in __call__
    x = layer(x, key=key)
        ^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/site-packages/equinox/nn/_sequential.py", line 66, in __call__
    x, state = layer(x, state=state, key=key)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/equinox/lib/python3.11/site-packages/equinox/nn/_batch_norm.py", line 155, in __call__
    first_time = state.get(self.first_time_index)
                 ^^^^^^^^^
AttributeError: 'object' object has no attribute 'get'
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
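The failure mode in this traceback can be reproduced with a toy model in plain Python. Everything below is a hypothetical stand-in: the Sequential is a guess at the dispatch logic based on the traceback (a non-recursive isinstance check, plus an assumed object() sentinel as the default state), not the actual Equinox source.

```python
_sentinel = object()  # assumed default for "no state was passed"


class StatefulLayer:
    """Stand-in marker base class (NOT equinox.nn.StatefulLayer)."""


class Counter(StatefulLayer):
    """Toy stateful layer; needs a dict-like state with a .get method."""

    def __call__(self, x, state):
        calls = state.get("calls", 0) + 1
        return x, {**state, "calls": calls}


class Sequential:
    """Toy container with a non-recursive statefulness check."""

    def __init__(self, layers):
        self.layers = list(layers)

    def __call__(self, x, state=_sentinel):
        for layer in self.layers:
            if isinstance(layer, StatefulLayer):
                x, state = layer(x, state)
            else:
                # A nested Sequential lands here, so it never sees the state.
                x = layer(x)
        return x if state is _sentinel else (x, state)


def double(x):
    return 2 * x


flat_result = Sequential([double, Counter()])(1.0, {})

message = None
try:
    Sequential([Sequential([double, Counter()])])(1.0, {})
except AttributeError as err:
    message = str(err)
print(flat_result, message)
```

The flat container works, while the nested one calls the inner container without a state, so the inner stateful layer receives the sentinel and fails in the same way as the BatchNorm call above.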
Ah, I take your point -- you'd like a way to programmatically propagate statefulness through, since the containing classes (whether they are a Sequential or a custom class) may or may not be stateful, depending on their choice of sublayers.
I've just written https://github.com/patrick-kidger/equinox/pull/505, which adds a StatefulLayer.is_stateful method, which is called to check whether or not a layer is stateful. And indeed Sequential now inherits from StatefulLayer, with an implementation of
def is_stateful(self):
    return any(isinstance(x, StatefulLayer) and x.is_stateful() for x in self.layers)
Thus nested sequentials should now automatically work. In addition, you should be able to have your own classes inherit from StatefulLayer and implement is_stateful, so that your custom layers can also be handled statefully when placed inside a Sequential, if required.
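To see why a recursive is_stateful resolves the nesting problem, here is a toy version in plain Python. The classes are hypothetical stand-ins that mirror the is_stateful snippet above; the default of True on the base class and the object() sentinel are assumptions made for this sketch, not a copy of the Equinox code.

```python
_sentinel = object()  # assumed default for "no state was passed"


class StatefulLayer:
    """Stand-in base class (NOT equinox.nn.StatefulLayer)."""

    def is_stateful(self):
        return True  # assumed default: subclasses are stateful unless they say otherwise


class Counter(StatefulLayer):
    """Toy stateful layer that counts its calls in a dict state."""

    def __call__(self, x, state):
        calls = state.get("calls", 0) + 1
        return x, {**state, "calls": calls}


class Sequential(StatefulLayer):
    """Toy container whose statefulness depends on its children."""

    def __init__(self, layers):
        self.layers = list(layers)

    def is_stateful(self):
        return any(
            isinstance(x, StatefulLayer) and x.is_stateful() for x in self.layers
        )

    def __call__(self, x, state=_sentinel):
        for layer in self.layers:
            if isinstance(layer, StatefulLayer) and layer.is_stateful():
                x, state = layer(x, state)  # nested containers get the state too
            else:
                x = layer(x)
        return x if state is _sentinel else (x, state)


def double(x):
    return 2 * x


nested = Sequential([Sequential([double, Counter()]), Sequential([double])])
result = nested(1.0, {})
print(result)
```

The first inner container reports itself as stateful and receives the state; the second contains only a plain function, reports itself as stateless, and is called without one.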
Does that seem like it would work for you? If you can install Equinox from that branch and check that it meets your needs, then I'll include it in the upcoming release.
That's exactly what I meant and your proposed solution sounds great!
I haven't been able to complete my testing as I ran into the Cannot assign methods in __init__ error introduced in v0.11.0 (gets thrown e.g. here for the ResNet).
As it seems to be common for many models in eqxvision to create layer blocks in __init__ with class functions (probably taken from the equivalent PyTorch implementations), I was wondering what the best approach to refactoring this would be before I start.
We don't want to create the layer blocks anew every time we call them, so I thought about using functools.cached_property.
This still seems like a lot of additional code compared to the previous implementations, so I thought you might have had a better idea in mind when implementing the error that I haven't thought of yet.
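For reference, this is roughly what the functools.cached_property idea looks like on a plain Python class. ToyResNet, _make_layer and the builds counter are hypothetical names for this sketch, and whether the same pattern plays well with Equinox modules is not checked here.

```python
from functools import cached_property


class ToyResNet:
    """Plain Python class; NOT an eqx.Module."""

    def __init__(self, widths):
        self.widths = tuple(widths)
        self.builds = 0  # instrumentation for this example only

    def _make_layer(self, width):
        # Placeholder for building a real block of layers.
        return [f"block(width={width})" for _ in range(2)]

    @cached_property
    def layers(self):
        # Runs on first access; the result is then stored on the instance.
        self.builds += 1
        return [self._make_layer(w) for w in self.widths]


net = ToyResNet([64, 128])
first = net.layers
second = net.layers  # served from the cache, no rebuild
print(net.builds)
```

The blocks are built once on first access and reused afterwards, which avoids both rebuilding them per call and storing a bound method in __init__.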
Great, I'm glad it works. I've just merged that change into dev.
As for that error -- that's a bug in something I just wrote, whoops. I think https://github.com/patrick-kidger/equinox/pull/508 should fix it.
All models using BatchNorm should work now. Tested with equinox/dev commit d9b018a.
@hlzl Hi, thanks for the PR and sorry for the late response.
I think the tests are currently failing because the equinox updates are not yet packaged into a new release.
I'll test it locally, then update and merge accordingly soon. I'm currently tied down this week with a few deadlines.
I'm planning on doing the next Equinox release in the next week, by the way.
No worries, thank you @paganpasta.
The models generally run, but it seems like they do not reproduce the same results as their PyTorch equivalents.
E.g., trying to reproduce results on CIFAR10 with a PyTorch ResNet18 and an equivalent equinox implementation:
Somehow, the equinox implementation matches the torch implementation during the first (and second) epoch, then stops learning anything useful and eventually falls back to random guessing.
The gradients behave super strangely starting in the third epoch, but I'm not able to tell what causes the issue and leads to an exploding loss, particularly because this setup should be identical to the PyTorch one, where this problem does not arise.
Examples to reproduce with the current dev branch of equinox and torch==2.0:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from tqdm import tqdm


def accuracy(outputs, labels):
    # Fraction of samples whose highest-scoring class matches the label.
    _, predicted = torch.max(outputs.data, 1)
    correct = (predicted == labels).sum().item()
    total = labels.size(0)
    return correct / total


######################################################
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)
batch_size = 256
num_epochs = 100

transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
trainset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

# weights=None replaces the deprecated pretrained=False (random initialisation).
model = torchvision.models.resnet18(weights=None)
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

for epoch in range(num_epochs):
    running_loss = 0.0
    running_accuracy = 0.0
    for images, labels in tqdm(
        trainloader, leave=False, desc=f"Epoch {epoch+1}", total=len(trainloader)
    ):
        images = images.to(device)
        labels = labels.to(device)
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        running_accuracy += accuracy(outputs, labels) * 100
    epoch_loss = running_loss / len(trainloader)
    epoch_accuracy = running_accuracy / len(trainloader)
    print(
        f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}"
    )
vs.
import equinox as eqx
import eqxvision
import jax
import jax.numpy as jnp
import optax
import torch
import torchvision
from tqdm import tqdm


def xavier_init(weight: jax.Array, key: jax.random.PRNGKey) -> jax.Array:
    # Normal init scaled by sqrt(2 / nin), where nin is taken from
    # weight.shape[1] only (any further kernel dimensions are not counted).
    nin = weight.shape[1]
    return jax.random.normal(key, weight.shape) * jnp.sqrt(2.0 / nin)


def init_weights(model, init_fn, key):
    # Re-initialise the weight of every Linear and Conv layer with init_fn.
    is_weight = lambda x: isinstance(x, (eqx.nn.Linear, eqx.nn.Conv))
    get_weights = lambda m: [
        x.weight
        for x in jax.tree_util.tree_leaves(m, is_leaf=is_weight)
        if is_weight(x)
    ]
    weights = get_weights(model)
    new_weights = [
        init_fn(weight, subkey)
        for weight, subkey in zip(weights, jax.random.split(key, len(weights)))
    ]
    new_model = eqx.tree_at(get_weights, model, new_weights)
    return new_model


@eqx.filter_jit
def loss(model, state, x, label):
    keys = jax.random.split(jax.random.PRNGKey(5678), x.shape[0])
    # Map over the batch axis of x only; the state is shared across the batch.
    batch_model = jax.vmap(
        model, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
    )
    pred_ys, state = batch_model(x, state, key=keys)
    return optax.softmax_cross_entropy_with_integer_labels(pred_ys, label).mean(), state


@eqx.filter_jit
def make_step(model, state, opt_state, x, label):
    # The updated state is returned as the auxiliary output of the loss.
    (val, state), grads = eqx.filter_value_and_grad(loss, has_aux=True)(
        model, state, x, label
    )
    updates, opt_state = opt.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, state, opt_state, val


@eqx.filter_jit
def inference(model, state, x):
    keys = jax.random.split(jax.random.PRNGKey(5678), x.shape[0])
    # Switch the model to inference mode and bind the current state.
    inference_model = eqx.Partial(eqx.tree_inference(model, value=True), state=state)
    return jax.vmap(inference_model)(x, key=keys)


@eqx.filter_jit
def accuracy(outputs, labels):
    predicted = jnp.argmax(outputs, 1)
    correct = (predicted == labels).sum()
    total = labels.size
    return correct / total


######################################################
batch_size = 256
num_epochs = 100

transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
trainset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

model = eqxvision.models.resnet18(num_classes=10)
model = init_weights(model, xavier_init, key=jax.random.PRNGKey(5678))
opt = optax.sgd(learning_rate=0.01, momentum=0.9)
state = eqx.nn.State(model)
opt_state = opt.init(eqx.filter(model, eqx.is_inexact_array))

for epoch in range(0, num_epochs):
    running_loss, running_accuracy = 0.0, 0.0
    for images, labels in tqdm(
        trainloader, leave=False, desc=f"Epoch {epoch+1}", total=len(trainloader)
    ):
        model, state, opt_state, loss_val = make_step(
            model,
            state,
            opt_state,
            images.numpy(),
            labels.numpy(),
        )
        # Re-evaluate the training batch in inference mode for the accuracy.
        out = inference(model, state, images.numpy())
        running_loss += loss_val
        running_accuracy += accuracy(out[0], labels.numpy()) * 100
    epoch_loss = running_loss / len(trainloader)
    epoch_accuracy = running_accuracy / len(trainloader)
    print(
        f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}"
    )
Would be great if one of you could have a look here.
Hmm two quick thoughts:
FWIW PyTorch and Equinox use slightly different batch norm implementations (there are many variants). But once you've removed batch norm, you could try initialising them with the same weights, and use the same training batches, and see if you can get close to bit-for-bit reproducibility.
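To make "there are many variants" concrete, here is a plain-Python sketch of two generic batch norm variants for a single feature. It is illustrative only and makes no claim about which library implements which variant; the function names, the decay parameter and the initial running statistics are assumptions of this sketch.

```python
import math


def batch_stats(xs):
    """Mean and (biased) variance of one feature over the batch."""
    mean = sum(xs) / len(xs)
    var = sum((x - mean) ** 2 for x in xs) / len(xs)
    return mean, var


def update_running(running, batch, decay):
    """Exponential moving average of (mean, var)."""
    return tuple(decay * r + (1.0 - decay) * b for r, b in zip(running, batch))


def normalise(xs, mean, var, eps=1e-5):
    return [(x - mean) / math.sqrt(var + eps) for x in xs]


def bn_with_batch_stats(xs, running, decay=0.9):
    """Variant A: normalise with this batch's statistics; only track the EMA."""
    stats = batch_stats(xs)
    return normalise(xs, *stats), update_running(running, stats, decay)


def bn_with_running_stats(xs, running, decay=0.9):
    """Variant B: update the EMA first, then normalise with the EMA."""
    running = update_running(running, batch_stats(xs), decay)
    return normalise(xs, *running), running


batch = [10.0, 12.0, 14.0]
running = (0.0, 1.0)  # (running mean, running variance) at initialisation
out_a, _ = bn_with_batch_stats(batch, running)
out_b, _ = bn_with_running_stats(batch, running)
mean_a = sum(out_a) / len(out_a)
mean_b = sum(out_b) / len(out_b)
print(mean_a, mean_b)
```

Early in training the two variants can produce very different activations for the same batch: variant A centres the batch at zero, while variant B is still dominated by the initial running statistics. A difference of this kind is one reason to take batch norm out of the comparison before looking for bit-for-bit agreement.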
Btw, heads-up that Equinox v0.11.0 is now released! That shouldn't be a blocker any more for this PR.
@patrick-kidger Sorry for the late reply. Regarding your two thoughts:
BTW, my simple test for comparing the two implementations was to overfit on CIFAR10 during training. This generally works out of the box on most ResNet implementations, but somehow I was not able to achieve this for the equinox implementation with batch norm. Do you have any idea why this could be (or if this is due to the different bn variant)?
Trying to revive PR #71 and solve issue #70. These changes should complete the refactor and use the new state mechanic. However, currently this causes the following error in the Sequential module:
Can be reproduced with:
@patrick-kidger Any idea how to solve this? Should we get rid of the if-else statement causing this error that checks for isinstance(layer, StatefulLayer) in _sequential?