tomsch420 opened 1 month ago
There is a risk to the suggested approach that should at least be highlighted in the docs: the parameters you are trying to freeze may still be penalized by regularization.
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array
from optax import adamw


class Model(eqx.Module):
    buffer: Array
    param: Array

    def __call__(self, x):
        return self.param * x + jax.lax.stop_gradient(self.buffer)


@eqx.filter_value_and_grad
def loss(model, x):
    return model(x)


model = Model(jnp.ones(()), jnp.ones(()))
loss, grad = loss(model, 2)

optimizer = adamw(1e-1)  # Optimizer with regularization (weight decay)
opt_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))
updates, opt_state = optimizer.update(grad, opt_state, eqx.filter(model, eqx.is_array))
model = eqx.apply_updates(model, updates)

assert model.buffer == jnp.ones(())  # Fails! Weight decay still moved `buffer`, even though its gradient is zero.
Unless I am missing a downside, the approach I think should be recommended is to use a wrapper class (NonTrainable) to wrap non-trainable nodes, and to partition the parameters, e.g. with:
params, static = eqx.partition(
    model,
    eqx.is_inexact_array,
    is_leaf=lambda leaf: isinstance(leaf, NonTrainable),
)
Ah! That really isn't very good, you're right.
Hmm, I'm trying to figure out if there's a way to handle this ergonomically. The best I can come up with is to wrap the Optax calls (like we already do for eqx.apply_updates) with something that respects such a NonTrainable wrapper. This is just such an easy footgun!
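For reference, one way to make the optimizer itself respect frozen leaves today, without any Equinox-side wrapper, is optax.multi_transform with optax.set_to_zero() on the frozen partition. A sketch, reusing the Model class and grad pytree from the example above (the label names are arbitrary):

```python
import equinox as eqx
import jax
import optax

# Label every leaf as either "train" or "freeze". Frozen leaves are routed
# through set_to_zero(), so neither the Adam step nor the weight-decay term
# can move them.
param_labels = jax.tree_util.tree_map(lambda _: "train", model)
param_labels = eqx.tree_at(lambda m: m.buffer, param_labels, replace="freeze")

optimizer = optax.multi_transform(
    {"train": optax.adamw(1e-1), "freeze": optax.set_to_zero()},
    param_labels,
)
opt_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))
updates, opt_state = optimizer.update(grad, opt_state, eqx.filter(model, eqx.is_inexact_array))
model = eqx.apply_updates(model, updates)  # model.buffer is left untouched by this step
```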
FWIW I've landed on the optax wrapper approach. I have a trainable/non_trainable mask that I create early on and partition that way. I don't even bother with stop_grad most of the time and pray that XLA does the DCE for me (it seems to).
For things that are really constants (e.g. rotary embeddings) I just materialize those in the kernel with ensure_compile_time_eval.
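In case it's useful to anyone landing here, a rough sketch of what that mask-then-partition workflow can look like, followed by the compile-time-constant trick. The Net model below is made up purely for illustration:

```python
import equinox as eqx
import jax
import jax.numpy as jnp


class Net(eqx.Module):
    table: jax.Array        # frozen, e.g. a precomputed table
    linear: eqx.nn.Linear   # trainable


model = Net(jnp.ones((4, 4)), eqx.nn.Linear(4, 2, key=jax.random.PRNGKey(0)))

# Build the trainable/non-trainable mask once, up front.
trainable = jax.tree_util.tree_map(lambda _: True, model)
trainable = eqx.tree_at(lambda m: m.table, trainable, replace=False)

# Only `params` is ever handed to the optimizer, so the frozen leaves see
# neither gradients nor weight decay; they come back in via eqx.combine.
params, static = eqx.partition(model, trainable)


def loss(params, static, x):
    model = eqx.combine(params, static)
    return jnp.sum(model.linear(model.table @ x))


# For true constants (e.g. rotary tables), another option is to compute them
# at trace time so they are baked into the compiled function as literals:
@jax.jit
def rotary(x):
    with jax.ensure_compile_time_eval():
        table = jnp.exp(-0.1 * jnp.arange(8))  # evaluated eagerly while tracing
    return x * table
```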
Ah, nice! Okay, I think I'm convinced.
I'd be happy to take a PR implementing this, then.
Greetings!
I have custom layers in Equinox that look approximately like this.
I now want to exclude ProductLayer.edges from the parameters of a model, since they cannot be adjusted by gradient descent. Furthermore, SumLayer.log_weights.indices cannot be adjusted either, and neither can ContinuousLayerWithFiniteSupport.interval. How can I best filter these out for the eqx.partition method?
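For what it's worth, a sketch of the filter-spec pattern from the discussion above, applied to this situation. Since the layer definitions aren't shown here, `model` and the attribute path below are hypothetical placeholders; the useful points are that eqx.is_inexact_array already excludes integer arrays (which edge/index tensors likely are), and eqx.tree_at can switch off any remaining float-valued leaves:

```python
import equinox as eqx
import jax


def build_filter_spec(model):
    # True = trainable. Every inexact (floating-point) array starts out
    # trainable; integer-valued leaves (which edges/indices likely are)
    # come out as False automatically.
    spec = jax.tree_util.tree_map(eqx.is_inexact_array, model)
    # Switch off float-valued leaves that gradient descent must not touch.
    # The path below is a placeholder for wherever the layer actually lives.
    spec = eqx.tree_at(
        lambda m: m.continuous_layer.interval,  # hypothetical path
        spec,
        replace=False,
    )
    return spec


params, static = eqx.partition(model, build_filter_spec(model))
```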