patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

optax.MultiSteps #853

Closed haydn-jones closed 2 months ago

haydn-jones commented 2 months ago

I'm having a hard time getting optax.MultiSteps to work. With my own model I was getting an error that made it seem like the optimizer's tree structure was changing between updates, so to simplify the issue I tried wrapping the optimizer in the mnist.ipynb example with optim = optax.MultiSteps(..., every_k_schedule=4). This resulted in a very different error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[16], line 1
----> 1 model = train(model, trainloader, testloader, optim, STEPS, PRINT_EVERY)

Cell In[15], line 38
     36 x = x.numpy()
     37 y = y.numpy()
---> 38 model, opt_state, train_loss = make_step(model, opt_state, x, y)
     39 if (step % print_every) == 0 or (step == steps - 1):
     40     test_loss, test_accuracy = evaluate(model, testloader)

    [... skipping hidden 15 frame]

Cell In[15], line 24
     16 @eqx.filter_jit
     17 def make_step(
     18     model: CNN,
   (...)
     21     y: Int[Array, " batch"],
     22 ):
     23     loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
---> 24     updates, opt_state = optim.update(grads, opt_state, model)
     25     model = eqx.apply_updates(model, updates)
     26     return model, opt_state, loss_value

File .venv/lib/python3.12/site-packages/optax/transforms/_accumulation.py:380, in MultiSteps.update(self, updates, state, params, **extra_args)
    377   zero_updates = otu.tree_zeros_like(state.acc_grads)
    378   return zero_updates, multi_state_when_skip
--> 380 new_updates, new_state = lax.cond(
    381     should_skip_update, _skip_update, _do_update, *(updates, state, params)
    382 )
    383 return new_updates, new_state

    [... skipping hidden 5 frame]

File .venv/lib/python3.12/site-packages/jax/_src/core.py:1508, in concrete_aval(x)
   1506 if hasattr(x, '__jax_array__'):
   1507   return concrete_aval(x.__jax_array__())
-> 1508 raise TypeError(f"Value {x!r} with type {type(x)} is not a valid JAX "
   1509                  "type")

TypeError: Value <function max at 0x7ff5b87ed580> with type <class 'function'> is not a valid JAX type

It's unclear to me whether this is user error or an issue with optax/equinox, but I don't remember any special changes being needed when wrapping an optimizer in MultiSteps. Shouldn't the mnist example just work when wrapped like that?

patrick-kidger commented 2 months ago

Quite likely this is some version of https://docs.kidger.site/equinox/faq/#optax-throwing-a-typeerror
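
Roughly, the pattern that FAQ entry describes is to initialise (and update) optax on a filtered view of the model. A minimal sketch, using a generic eqx.nn.MLP and optax.adamw rather than the notebook in this issue:

import equinox as eqx
import jax
import optax

model = eqx.nn.MLP(in_size=2, out_size=1, width_size=8, depth=2, key=jax.random.PRNGKey(0))
optim = optax.adamw(3e-4)

# Only floating-point array leaves are handed to optax; non-JAX leaves such as
# activation functions are filtered out and never reach the optimiser.
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))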

haydn-jones commented 2 months ago

The MNIST example already filters by is_array, though I also updated it to is_inexact_array and had the same issue.

EDIT: With my model, changing from is_array to is_inexact_array fixed it, thanks for pointing that out! I think some of the state in dropout was getting traced that shouldn't have been, which makes sense. I still don't see what I'm doing wrong with the MNIST example, though.
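
(For anyone else hitting this, a toy illustration of the difference, not the actual model from this issue: eqx.is_array keeps every JAX/NumPy array leaf, including integer ones such as PRNG keys, whereas eqx.is_inexact_array keeps only floating-point and complex leaves, which is what optax should see.)

import equinox as eqx
import jax
import jax.numpy as jnp

tree = {"weights": jnp.ones(3), "key": jax.random.PRNGKey(0)}  # PRNGKey is a uint32 array

# is_array keeps both leaves; is_inexact_array drops the integer key (replaced by None).
print(eqx.filter(tree, eqx.is_array))
print(eqx.filter(tree, eqx.is_inexact_array))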

lockwo commented 2 months ago

"state in dropout was getting traced that shouldn't have been"

Dropout seems to be a common problem child (see https://github.com/patrick-kidger/equinox/issues/772 and https://github.com/patrick-kidger/equinox/issues/681). If you post a minimal reproducible example for the MNIST case, that could help.

haydn-jones commented 2 months ago

Pared-down MNIST notebook, with the optimizer wrapped in MultiSteps (see the optim = optax.MultiSteps(...) line below):

import equinox as eqx
import jax
import jax.numpy as jnp
import optax  # https://github.com/deepmind/optax
import torch  # https://pytorch.org
import torchvision  # https://pytorch.org
from jaxtyping import Array, Float, Int, PyTree  # https://github.com/google/jaxtyping

BATCH_SIZE = 64
LEARNING_RATE = 3e-4
STEPS = 300
PRINT_EVERY = 30
SEED = 5678

key = jax.random.PRNGKey(SEED)

normalise_data = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)
train_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=True,
    download=True,
    transform=normalise_data,
)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

class CNN(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        self.layers = [
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(1728, 512, key=key2),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=key3),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=key4),
            jax.nn.log_softmax,
        ]

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        for layer in self.layers:
            x = layer(x)
        return x

key, subkey = jax.random.split(key, 2)
model = CNN(subkey)

def loss(model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]) -> Float[Array, ""]:
    pred_y = jax.vmap(model)(x)
    return cross_entropy(y, pred_y)

def cross_entropy(y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]) -> Float[Array, ""]:
    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(pred_y)

# optim = optax.adamw(LEARNING_RATE)
optim = optax.MultiSteps(optax.adamw(LEARNING_RATE), every_k_schedule=4)

def train(
    model: CNN,
    trainloader: torch.utils.data.DataLoader,
    optim: optax.GradientTransformation,
    steps: int,
    print_every: int,
) -> CNN:
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

    @eqx.filter_jit
    def make_step(
        model: CNN,
        opt_state: PyTree,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        updates, opt_state = optim.update(grads, opt_state, model)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    def infinite_trainloader():
        while True:
            yield from trainloader

    for step, (x, y) in zip(range(steps), infinite_trainloader()):
        x = x.numpy()
        y = y.numpy()
        model, opt_state, train_loss = make_step(model, opt_state, x, y)
        if (step % print_every) == 0 or (step == steps - 1):
            print(f"{step=}, train_loss={train_loss.item()}, ")
    return model

model = train(model, trainloader, optim, STEPS, PRINT_EVERY)

lockwo commented 2 months ago

The fix I saw was just to filter the model during the update: updates, opt_state = optim.update(grads, opt_state, eqx.filter(model, eqx.is_array)) worked for me. I believe the reason is this: in normal Adam, things get tree-mapped over the unfiltered model, but that's fine because optax passes the gradients as the first argument to the tree map (and that first tree is the one JAX uses to determine the leaves; if you swapped them it would break, since the Nones in the gradients would then be counted as leaves). But in MultiSteps there is a cond that the model is an input to (https://github.com/google-deepmind/optax/blob/main/optax/transforms/_accumulation.py#L380), so if you don't filter you get an error, because it tries to pass a pytree containing non-JAX types through the cond.
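
Applied to the notebook above, only the optim.update call in make_step changes; a sketch of the fixed step function, with everything else as posted:

@eqx.filter_jit
def make_step(
    model: CNN,
    opt_state: PyTree,
    x: Float[Array, "batch 1 28 28"],
    y: Int[Array, " batch"],
):
    loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
    # Pass only the array leaves of the model to optax, so MultiSteps' internal
    # lax.cond never sees non-JAX leaves such as jax.nn.relu or jnp.ravel.
    updates, opt_state = optim.update(grads, opt_state, eqx.filter(model, eqx.is_array))
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss_value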

haydn-jones commented 2 months ago

Oh yep, that's correct, thanks! Closing this.