Open ToddMorrill opened 2 months ago
I think the easiest way to work around this Optax check is just to wrap your parameters into a non-callable pytree, e.g. a length-1 list: `[your_pytree_here]`.
From an Optax point of view, Equinox models are just PyTrees of parameters, and you should be able to reason about them in the same way.
I'm looking for the same functionality. Did you find an elegant solution in the meantime?
Wrapping the parameters in a length-1 list works for me:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax


class Model(eqx.Module):
    a: float
    b: float

    def __call__(self):
        return self.a + self.b


model = Model(jnp.array(7.), jnp.array(7.))


def loss_fn(model):
    return model()


def param_labels(listed_model):
    # Unwrap the length-1 list, label every leaf 'a', then relabel `b`.
    model = listed_model[0]
    labels = jax.tree_util.tree_map(lambda _: 'a', model)
    labels = eqx.tree_at(lambda m: m.b, labels, 'b')
    # Re-wrap in a list so the returned label tree is not callable.
    return [labels]


optimizer = optax.multi_transform(
    {'a': optax.adam(1e-2), 'b': optax.adam(1e-1)}, param_labels=param_labels
)
opt_state = optimizer.init(eqx.filter([model], eqx.is_array))

for step in range(10):
    grads = eqx.filter_jit(eqx.filter_grad(loss_fn))(model)
    updates, opt_state = optimizer.update([grads], opt_state, [model])
    model = eqx.apply_updates(model, updates[0])
    print(model.a, model.b)
Optax has a function, `multi_transform`, which is nice for using multiple optimizers. See here for a working example.

How should Equinox interact with `multi_transform`?

Suppose, for instance, I want to freeze the last layer of my network and I want to use Optax's `set_to_zero` function. NB: I'm aware of this example, but it's less general than my question because it only addresses the case where you want to freeze a set of parameters (by excluding them from the differentiable set) and doesn't address the case where we want to apply a different Optax optimizer to different sets of parameters.

My code might look something like:
This results in an error because `param_labels` is a PyTree just like the model, and the model has a `__call__` method implemented, which makes it callable. So then under the hood Optax will check if the mask (i.e., the `param_labels`) is callable. Since it is callable, it ends up calling `param_labels(model_params)`, which isn't what we want.

We can't just delete the `__call__` method from the `param_labels` object because it's a frozen dataclass. So it seems like the only way forward is to make `param_labels` a callable function that somehow labels PyTree nodes appropriately. I'm not enough of a PyTree pro to know if there's an easy way to manipulate PyTrees to easily target the weight matrix of the last layer. Does anyone have a sense for how to do this?