LouisDesdoigts closed this issue 2 years ago.
Hey there. The good news is that `optax.multi_transform` works fine with Equinox.
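Regarding what you've written about `optax.multi_transform` not working because your parameters are stored in Equinox objects rather than nested dictionaries: `optax.multi_transform` works with any PyTree.

Regarding how to use Equinox and `optax.multi_transform` together, here's an example: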
import equinox as eqx
import jax
import jax.random as jr
import optax
key1, key2 = jr.split(jr.PRNGKey(0))
mlp1 = eqx.nn.MLP(2, 2, 2, 2, key=key1)
mlp2 = eqx.nn.MLP(2, 2, 2, 2, key=key2)
# Example model. In its interaction with `optax.multi_transform`, all that matters
# is that it is some PyTree of parameters.
model = (mlp1, mlp2)
#
# Example 1: use different learning rates for different MLPs.
#
# PyTree prefix of `model`
param_spec = ("group1", "group2")
optim = optax.multi_transform(
{"group1": optax.adam(1e-1), "group2": optax.adam(1e-2)},
param_spec
)
# Filter as per https://docs.kidger.site/equinox/faq/#optax-is-throwing-an-error
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
#
# Example 2: use different learning rates for weights/biases.
#
# Use group1 by default
param_spec = jax.tree_util.tree_map(lambda _: "group1", model)
# Use group2 for biases
has_bias = lambda x: hasattr(x, "bias")
where_bias = lambda m: tuple(x.bias for x in jax.tree_util.tree_leaves(m, is_leaf=has_bias) if has_bias(x))
param_spec = eqx.tree_at(where_bias, param_spec, replace_fn=lambda _: "group2")
# Now the same use of `optax.multi_transform` as before
optim = optax.multi_transform(
{"group1": optax.adam(1e-1), "group2": optax.adam(1e-2)},
param_spec
)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
Hey, thanks for the quick response! Unfortunately this doesn't seem to have solved our issue.
I do now have a minimal example that reproduces the bug, which should help much more this time! You can find our package here (although you will need to install it via the setup.py file, not PyPI): https://github.com/LouisDesdoigts/dLux
import jax
from jax.config import config
config.update("jax_enable_x64", True)
import jax.numpy as np
import numpy as onp
import equinox as eqx
import optax
import dLux
from dLux.base import *
from dLux.layers import *
from dLux.propagators import *
# Define input parameters
onp.random.seed(0)
npix, Nstars, Nims = 256, 10, 50
dev = 1e-5
fluxes = 1e6*np.array(onp.random.rand(Nstars))
positions = dev * np.array(2*onp.random.rand(2*Nstars).reshape([Nstars, 2]) - 1)
shifts = dev * np.array(2*onp.random.rand(2*Nims).reshape([Nims, 2]) - 1) * 0.1
wavels = np.linspace(400e-9, 600e-9, 5)
optical_layers = [
CreateWavefront(npix, 1.0),
NormaliseWavefront(npix),
ApplyOPD(npix, np.zeros([npix, npix])),
MFT(npix, npix, 1, 10., 1e-6),
Wavefront2PSF(npix)]
# Construct Model
osys = OpticalSystem(optical_layers)
model = Scene(osys, wavels, positions, fluxes, dithers=shifts)
images = model() # Test run the model
# Default values to group 0
param_spec = jax.tree_map(lambda _: "group0", model)
# Set parameter groups
param_spec = eqx.tree_at(lambda scene: scene.fluxes, param_spec, replace='group1')
param_spec = eqx.tree_at(lambda scene: scene.optical_system.layers[2].array, param_spec, replace='group2')
optim = optax.multi_transform(
{"group0": optax.adam(0.0),
"group1": optax.adam(1e-0),
"group2": optax.adam(1e-6),
},
param_spec
)
opt_state = optim.init(eqx.filter(param_spec, eqx.is_inexact_array))
That code produces this error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Input In [7], in <cell line: 16>()
6 param_spec = eqx.tree_at(lambda scene: scene.optical_system.layers[2].array, param_spec, replace='group2')
8 optim = optax.multi_transform(
9 {"group0": optax.adam(0.0),
10 "group1": optax.adam(1e-0),
(...)
13 param_spec
14 )
---> 16 opt_state = optim.init(eqx.filter(param_spec, eqx.is_inexact_array))
File ~/mambaforge/envs/morph/lib/python3.8/site-packages/optax/_src/combine.py:127, in multi_transform.<locals>.init_fn(params)
126 def init_fn(params):
--> 127 labels = param_labels(params) if callable(param_labels) else param_labels
129 label_set = set(jax.tree_leaves(labels))
130 if not label_set.issubset(transforms.keys()):
TypeError: __call__() takes 1 positional argument but 2 were given
I am really not sure what is going on here, so please let me know if you need any other information! Cheers
Looks like the `callable` branch is triggering. Try wrapping both your `model` and `param_spec` into a list of length 1 (which is not a callable).
Fantastic! This has worked! Thanks so much for your help and for building such an awesome package :D
Hi Patrick,
We are currently developing an auto-diff optical modelling package that makes heavy use of Equinox (dLux, up on PyPI in the next few days)! We are, however, running into some issues when trying to optimize our models using optax, and I believe the problem stems from the way Equinox is interacting with optax. We are also hoping that our use case can help build functionality into Equinox for optimizing more complex nested Equinox Modules with different optax optimizers/schedules.
The basic structure of the package is an `equinox.Module` in which some parameters are arrays, some are lists of other `equinox.Module` objects, and others are `equinox.Module`s that contain lists of other `equinox.Module`s. Like this:
The nature of these models is that each parameter needs a different learning rate/optimization function. Optax has `optax.multi_transform()` to handle stitching together multiple optimizers, but this method does not work with our models because our parameters are stored in the Equinox objects, not in nested dictionaries. I have attempted to manually decompose our models into nested dictionaries, optimise the params in the dictionaries, and then rebuild a model from the updated parameters, but this throws a very strange error:
Which I cannot make any sense of. Interestingly, I was able to recreate this error using the `eqx.tree_equal()` function called on the `model` and `new_model` objects:
eqx.tree_equal(new_model, model)
So I am hoping you can help shed some light, or help develop some functionality to support the optimization of nested Equinox Modules with different learning rates! I am completely new to optax and Equinox, and only familiar with JAX at a high level, so it's entirely possible I am misunderstanding something fundamental about the way these packages interact.
Finally, these are the versions of all the packages that I'm using:
jax: 0.3.4 eqx: 0.3.2 optax: 0.1.1
and it is all being run on an Apple M1 machine. Please let me know if you need any more information! Cheers