Closed: haydn-jones closed this issue 2 months ago
Quite likely this is some version of https://docs.kidger.site/equinox/faq/#optax-throwing-a-typeerror
The MNIST example already filters by `is_array`, though I also updated it to `is_inexact_array` and had the same issue.
EDIT: With my model, changing from `is_array` to `is_inexact_array` fixed it, thanks for pointing that out! I think some of the state in dropout was getting traced that shouldn't have been, which makes sense. I still don't see what I'm doing wrong with the MNIST example, though.
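A minimal illustration of the difference between the two filters (not from the notebook; the `params` pytree here is just a toy example):

```python
import equinox as eqx
import jax.numpy as jnp

# Hypothetical pytree mixing float parameters with an integer counter.
params = {"w": jnp.ones((2, 2)), "step": jnp.array(0)}

# eqx.is_array keeps every JAX/NumPy array leaf, including integer ones...
print(eqx.filter(params, eqx.is_array))          # keeps both "w" and "step"
# ...while eqx.is_inexact_array keeps only floating/complex leaves, so the
# integer counter is replaced with None and never reaches the optimiser.
print(eqx.filter(params, eqx.is_inexact_array))  # keeps "w"; "step" -> None
```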
> state in dropout was getting traced that shouldn't have been
Dropout seems to be a common problem child (see https://github.com/patrick-kidger/equinox/issues/772 and https://github.com/patrick-kidger/equinox/issues/681). If you post the MVC for the MNIST example, that could help.
Pared-down MNIST notebook with the change to `optax.MultiSteps` (line 71):
```python
import equinox as eqx
import jax
import jax.numpy as jnp
import optax  # https://github.com/deepmind/optax
import torch  # https://pytorch.org
import torchvision  # https://pytorch.org
from jaxtyping import Array, Float, Int, PyTree  # https://github.com/google/jaxtyping

BATCH_SIZE = 64
LEARNING_RATE = 3e-4
STEPS = 300
PRINT_EVERY = 30
SEED = 5678

key = jax.random.PRNGKey(SEED)

normalise_data = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)
train_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=True,
    download=True,
    transform=normalise_data,
)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)


class CNN(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        self.layers = [
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(1728, 512, key=key2),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=key3),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=key4),
            jax.nn.log_softmax,
        ]

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        for layer in self.layers:
            x = layer(x)
        return x


key, subkey = jax.random.split(key, 2)
model = CNN(subkey)


def loss(model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]) -> Float[Array, ""]:
    pred_y = jax.vmap(model)(x)
    return cross_entropy(y, pred_y)


def cross_entropy(y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]) -> Float[Array, ""]:
    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(pred_y)


# optim = optax.adamw(LEARNING_RATE)
optim = optax.MultiSteps(optax.adamw(LEARNING_RATE), every_k_schedule=4)


def train(
    model: CNN,
    trainloader: torch.utils.data.DataLoader,
    optim: optax.GradientTransformation,
    steps: int,
    print_every: int,
) -> CNN:
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

    @eqx.filter_jit
    def make_step(
        model: CNN,
        opt_state: PyTree,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        updates, opt_state = optim.update(grads, opt_state, model)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    def infinite_trainloader():
        while True:
            yield from trainloader

    for step, (x, y) in zip(range(steps), infinite_trainloader()):
        x = x.numpy()
        y = y.numpy()
        model, opt_state, train_loss = make_step(model, opt_state, x, y)
        if (step % print_every) == 0 or (step == steps - 1):
            print(f"{step=}, train_loss={train_loss.item()}, ")

    return model


model = train(model, trainloader, optim, STEPS, PRINT_EVERY)
```
The fix I saw was just to filter the model during the update: `updates, opt_state = optim.update(grads, opt_state, eqx.filter(model, eqx.is_array))` worked for me. The reason, I believe, is this: with plain Adam, everything gets tree-mapped over the unfiltered model, but that is fine because optax passes the gradients as the first argument to the tree map (that is the tree from which JAX determines the leaves; if you swapped them it would break, since the `None`s in the gradients would then be counted as leaves). With MultiSteps, however, the model is passed into a `lax.cond` (https://github.com/google-deepmind/optax/blob/main/optax/transforms/_accumulation.py#L380), so if you don't filter, you get an error because it tries to `cond` a pytree that contains non-JAX types.
Oh yep, that's correct, thanks! Closing this.
I'm having a hard time getting `optax.MultiSteps` to work. With my own model I was getting an error that made it seem like the optimizer's tree structure was changing between updates, so to simplify the issue I tried wrapping the optimizer in the `mnist.ipynb` example with `optim = optax.MultiSteps(..., every_k_schedule=4)`. This resulted in a very different error. It's unclear to me whether this is user error or an issue with optax/equinox, but I don't remember any special changes being needed when wrapping an optimizer in MultiSteps. Should the MNIST example not just work when wrapped like that?