patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/

Issues when applying multiple optax optimizers within nested Equinox Modules #79

Closed LouisDesdoigts closed 2 years ago

LouisDesdoigts commented 2 years ago

Hi Patrick,

We are currently developing an auto-diff optical modelling package that makes heavy use of Equinox (dLux, up on PyPI in the next few days)! However, we are running into some issues when trying to optimize our models using optax, and I believe the problem stems from the way Equinox interacts with optax. We are also hoping that our use case can help build functionality into Equinox for optimizing more complex nested Equinox Modules with different optax optimizers/schedules.

The basic structure of the package is an eqx.Module in which some parameters are arrays, some are lists of other eqx.Module objects, and others are eqx.Modules which themselves contain lists of eqx.Modules. Like this:

Model(eqx.Module()):
 -> param1 = np.array([val0, val1, ....])
 -> param2 =  np.array([val0, val1, ....])
 -> param3 = [Layer1(eqx.Module()),
              Layer2(eqx.Module()),
              Layer3(eqx.Module())]

 -> param4 = SubModel(eqx.Module()):
      -> param1 = [Layer1(eqx.Module()),
                   Layer2(eqx.Module()),
                   Layer3(eqx.Module())]
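
In actual Equinox code, that structure is roughly the following (a minimal sketch; the class and field names are illustrative, not our real ones):

import equinox as eqx
import jax.numpy as np

class Layer(eqx.Module):
    # Stand-in for one of our real layer modules.
    weight: np.ndarray

class SubModel(eqx.Module):
    param1: list  # list of Layer modules

class Model(eqx.Module):
    param1: np.ndarray
    param2: np.ndarray
    param3: list  # list of Layer modules
    param4: SubModel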

The nature of these models is that each parameter needs a different learning rate/optimization function. Optax has optax.multi_transform() to handle stitching together multiple optimizers, but this method does not work with our models because our parameters are stored in Equinox objects, not nested dictionaries. I have attempted to manually decompose our models into nested dictionaries, optimize the params in the dictionary, and then rebuild a model from the updated parameters:

model_dict = model_to_dict(model)
optimizer = optax.multi_transform(
    {'param1': optax.adam(1e-6),
     'param2': optax.adam(1e3)},
    param_labels,  # pytree of labels matching the structure of model_dict
)

opt_state = optimizer.init(model_dict)

# Still pass the scene object to the loss function
loss, grads = loss_func(model, data)

grads_dict = model_to_dict(grads)
updates, opt_state = optimizer.update(grads_dict, opt_state)
model_dict = eqx.apply_updates(model_dict, updates)

new_model = dict_to_model(model_dict)

# Error thrown here
loss, grads = loss_func(new_model, data)

The final loss_func call on the rebuilt model throws a very strange error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [29], in <cell line: 11>()
     10 errors, grads_out, models = [], [], []
     11 for i in tqdm(range(5)):
---> 12     loss, grads = loss_func(opt_model, images)
     14     grads_dict = scene_to_dict(grads)
     16     # updates, state = tx.update(grads_dict, state, opt_dict)

File ~/mambaforge/envs/morph/lib/python3.8/functools.py:399, in partialmethod._make_unbound_method.<locals>._method(cls_or_self, *args, **keywords)
    397 def _method(cls_or_self, /, *args, **keywords):
    398     keywords = {**self.keywords, **keywords}
--> 399     return self.func(cls_or_self, *self.args, *args, **keywords)

File ~/mambaforge/envs/morph/lib/python3.8/site-packages/jax/_src/device_array.py:41, in _forward_method(attrname, self, fun, *args)
     40 def _forward_method(attrname, self, fun, *args):
---> 41   return fun(getattr(self, attrname), *args)

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

I cannot make any sense of this. Interestingly, I was able to recreate the error by calling eqx.tree_equal(new_model, model) on the rebuilt and original models.

So I am hoping you can shed some light, or help develop some functionality to support the optimization of nested Equinox Modules with different learning rates! I am completely new to optax and Equinox, and only familiar with JAX at a high level, so it's entirely possible I am misunderstanding something fundamental about the way these packages interact.

Finally, these are the versions of all the packages that I'm using:

jax: 0.3.4, equinox: 0.3.2, optax: 0.1.1

and it is all being run on an Apple M1 machine. Please let me know if you need any more information! Cheers

patrick-kidger commented 2 years ago

Hey there. The good news is that optax.multi_transform works fine with Equinox.

Regarding how to use Equinox and optax.multi_transform together, here's an example:

import equinox as eqx
import jax
import jax.random as jr
import optax

key1, key2 = jr.split(jr.PRNGKey(0))
mlp1 = eqx.nn.MLP(2, 2, 2, 2, key=key1)
mlp2 = eqx.nn.MLP(2, 2, 2, 2, key=key2)
# Example model. In its interaction with `optax.multi_transform`, all that matters
# is that it is some PyTree of parameters.
model = (mlp1, mlp2)

#
# Example 1: use different learning rates for different MLPs.
#

# PyTree prefix of `model`
param_spec = ("group1", "group2")
optim = optax.multi_transform(
    {"group1": optax.adam(1e-1), "group2": optax.adam(1e-2)},
    param_spec
)
# Filter as per https://docs.kidger.site/equinox/faq/#optax-is-throwing-an-error
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

#
# Example 2: use different learning rates for weights/biases.
#

# Use group1 by default
param_spec = jax.tree_map(lambda _: "group1", model)
# Use group2 for biases
has_bias = lambda x: hasattr(x, "bias")
where_bias = lambda m: tuple(x.bias for x in jax.tree_leaves(m, is_leaf=has_bias) if has_bias(x))
param_spec = eqx.tree_at(where_bias, param_spec, replace_fn=lambda _: "group2")
# Now the same use of `optax.multi_transform` as before
optim = optax.multi_transform(
    {"group1": optax.adam(1e-1), "group2": optax.adam(1e-2)},
    param_spec
)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
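
For completeness, the subsequent training step is the usual Equinox/optax pattern (a minimal sketch with a toy squared-error loss; `loss_fn` and `step` are illustrative names, not part of either library):

@eqx.filter_value_and_grad
def loss_fn(model, x, y):
    mlp1, mlp2 = model
    pred = mlp2(mlp1(x))
    return ((pred - y) ** 2).mean()

def step(model, opt_state, x, y):
    loss, grads = loss_fn(model, x, y)
    # `grads` has the same pytree structure as `model`, so the same
    # `param_spec` labels apply to it inside `optim.update`.
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state
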
LouisDesdoigts commented 2 years ago

Hey, thanks for the quick response! Unfortunately this doesn't seem to have solved our issue.

I now have a minimal example to reproduce the bug that should help much more this time! You can find our package here (you will need to install via the setup.py file, not PyPI): https://github.com/LouisDesdoigts/dLux

import jax
from jax.config import config
config.update("jax_enable_x64", True)
import jax.numpy as np
import numpy as onp
import equinox as eqx
import optax

import dLux
from dLux.base import *
from dLux.layers import *
from dLux.propagators import *

# Define input parameters
onp.random.seed(0)
npix, Nstars, Nims = 256, 10, 50
dev = 1e-5
fluxes = 1e6*np.array(onp.random.rand(Nstars))
positions = dev * np.array(2*onp.random.rand(2*Nstars).reshape([Nstars, 2]) - 1)
shifts = dev * np.array(2*onp.random.rand(2*Nims).reshape([Nims, 2]) - 1) * 0.1
wavels = np.linspace(400e-9, 600e-9, 5)

optical_layers = [
    CreateWavefront(npix, 1.0),
    NormaliseWavefront(npix),
    ApplyOPD(npix, np.zeros([npix, npix])),
    MFT(npix, npix, 1, 10., 1e-6),
    Wavefront2PSF(npix)]

# Construct Model
osys = OpticalSystem(optical_layers)
model = Scene(osys, wavels, positions, fluxes, dithers=shifts)
images = model() # Test run the model

# Default values to group 0
param_spec = jax.tree_map(lambda _: "group0", model)

# Set parameter groups
param_spec = eqx.tree_at(lambda scene: scene.fluxes, param_spec, replace='group1')
param_spec = eqx.tree_at(lambda scene: scene.optical_system.layers[2].array, param_spec, replace='group2')

optim = optax.multi_transform(
    {"group0": optax.adam(0.0), 
     "group1": optax.adam(1e-0),
     "group2": optax.adam(1e-6),
    },
    param_spec
)

opt_state = optim.init(eqx.filter(param_spec, eqx.is_inexact_array))

That code produces this error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [7], in <cell line: 16>()
      6 param_spec = eqx.tree_at(lambda scene: scene.optical_system.layers[2].array, param_spec, replace='group2')
      8 optim = optax.multi_transform(
      9     {"group0": optax.adam(0.0), 
     10      "group1": optax.adam(1e-0),
   (...)
     13     param_spec
     14 )
---> 16 opt_state = optim.init(eqx.filter(param_spec, eqx.is_inexact_array))

File ~/mambaforge/envs/morph/lib/python3.8/site-packages/optax/_src/combine.py:127, in multi_transform.<locals>.init_fn(params)
    126 def init_fn(params):
--> 127   labels = param_labels(params) if callable(param_labels) else param_labels
    129   label_set = set(jax.tree_leaves(labels))
    130   if not label_set.issubset(transforms.keys()):

TypeError: __call__() takes 1 positional argument but 2 were given

I am really not sure what is going on here. Please let me know if you need any other information! Cheers

patrick-kidger commented 2 years ago

Looks like the callable branch is triggering: your Scene is an eqx.Module with a __call__ method (you call model() above), so param_spec itself is callable and optax treats it as a label function. Try wrapping both your model and param_spec into a list of length 1, which is not callable, as sketched below.
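
For example, adapted from your snippet (a sketch, untested):

# `Scene` defines `__call__`, so `optax.multi_transform` takes its
# `callable(param_labels)` branch and tries to call `param_spec` on the
# params. Wrapping both in a length-1 list sidesteps this.
optim = optax.multi_transform(
    {"group0": optax.adam(0.0),
     "group1": optax.adam(1e-0),
     "group2": optax.adam(1e-6)},
    [param_spec],
)
# Note: initialise on the filtered model (not on param_spec), as in my
# example above.
opt_state = optim.init(eqx.filter([model], eqx.is_inexact_array))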

LouisDesdoigts commented 2 years ago

Fantastic! This has worked! Thanks so much for your help and for building such an awesome package :D