patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Optax with multiple optimizers #794

Open ToddMorrill opened 2 months ago

ToddMorrill commented 2 months ago

Optax has a function, multi_transform, that applies different optimizers to different subsets of the parameters. See here for a working example.

How should Equinox interact with multi_transform?

Suppose, for instance, I want to freeze the last layer of my network using Optax's set_to_zero function. NB: I'm aware of this example, but it's less general than my question: it only covers freezing a set of parameters (by excluding them from the differentiable set), not applying a different Optax optimizer to each set of parameters.

My code might look something like:

```python
import jax
import equinox as eqx
import optax

def create_param_labels(params):
    # Initialize every leaf with the label "train".
    param_labels = jax.tree.map(lambda _: "train", params)
    # Relabel the last layer's weight as "freeze".
    param_labels = eqx.tree_at(lambda tree: tree.layers[-1].weight, param_labels, replace="freeze")
    return param_labels

...
model_params, model_structure = eqx.partition(model, eqx.is_inexact_array)
# Build the labels from model_params so their structure matches what optimizer.init sees.
param_labels = create_param_labels(model_params)
optimizer = optax.multi_transform(
    {"train": optax.sgd(learning_rate=lr), "freeze": optax.set_to_zero()},
    param_labels=param_labels,
)
optimizer_state = optimizer.init(model_params)
```

This results in an error. param_labels is a PyTree of the same type as the model, and since the model class implements __call__, param_labels is callable too. Under the hood, Optax checks whether the labels (i.e., param_labels) are callable and, if so, calls param_labels(model_params) to compute them, which isn't what we want.

We can't just delete the __call__ method from the param_labels object because it's a frozen dataclass. So it seems like the only way forward is to make param_labels a callable function that somehow labels PyTree nodes appropriately. I'm not enough of a PyTree pro to know if there's an easy way to manipulate PyTrees to easily target the weight matrix of the last layer. Does anyone have a sense for how to do this?

patrick-kidger commented 2 months ago

I think the easiest way to work around this Optax check is just to wrap your parameters into a non-callable pytree, e.g. a length-1 list: [your_pytree_here].

From an Optax point of view, Equinox models are just PyTrees of parameters, and you should be able to reason about them in the same way.

rdaems commented 1 month ago

I'm looking for the same functionality. Did you find an elegant solution in the meantime?

rdaems commented 1 month ago

Wrapping the parameters in a length-1 list works for me:

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

class Model(eqx.Module):
    a: jax.Array
    b: jax.Array

    def __call__(self):
        return self.a + self.b

model = Model(jnp.array(7.), jnp.array(7.))

def loss_fn(model):
    return model()

def param_labels(listed_model):
    # Unwrap the length-1 list, label every leaf 'a', then relabel `b` as 'b'.
    model = listed_model[0]
    labels = jax.tree.map(lambda _: 'a', model)
    labels = eqx.tree_at(lambda m: m.b, labels, 'b')
    # Re-wrap so the labels (and the masks Optax derives from them) are not callable.
    return [labels]

optimizer = optax.multi_transform({'a': optax.adam(1e-2), 'b': optax.adam(1e-1)}, param_labels=param_labels)
opt_state = optimizer.init(eqx.filter([model], eqx.is_array))

# Build the jitted gradient function once, outside the training loop.
grad_fn = eqx.filter_jit(eqx.filter_grad(loss_fn))

for step in range(10):
    grads = grad_fn(model)
    updates, opt_state = optimizer.update([grads], opt_state, [model])
    model = eqx.apply_updates(model, updates[0])
    print(model.a, model.b)
```