patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

How to mark parameters as trainable or not? #866

Open tomsch420 opened 1 month ago

tomsch420 commented 1 month ago

Greetings!

I have custom layers in Equinox that look approximately like this:

class ProductLayer(InnerLayer):

    child_layers: List[Union[SumLayer, InputLayer]]
    edges: BCOO

class SumLayer(InnerLayer):

    log_weights: List[BCOO]
    child_layers: Union[List[ProductLayer], List[InputLayer]]

class ContinuousLayerWithFiniteSupport(ContinuousLayer, ABC):
    interval: jax.Array

I now want to exclude ProductLayer.edges from the parameters of a model, since they cannot be adjusted by gradient descent. Furthermore, SumLayer.log_weights.indices cannot be adjusted either, and neither can ContinuousLayerWithFiniteSupport.interval. How can I best filter these out for the eqx.partition method?

patrick-kidger commented 1 month ago

See the FAQ: https://docs.kidger.site/equinox/faq/#how-to-mark-arrays-as-non-trainable-like-pytorchs-buffers

danielward27 commented 2 weeks ago

There is a risk to the suggested approach that should at least be highlighted in the docs: the non-trainable arrays may still be penalized by regularization, such as the weight decay in AdamW.

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array
from optax import adamw

class Model(eqx.Module):
    buffer: Array
    param: Array

    def __call__(self, x):
        return self.param * x + jax.lax.stop_gradient(self.buffer)

@eqx.filter_value_and_grad
def loss(model, x):
    return model(x)

model = Model(jnp.ones(()), jnp.ones(()))
value, grad = loss(model, 2)  # Do not rebind `loss`, which is the function above
optimizer = adamw(1e-1)  # Optimizer with regularization (weight decay)
opt_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))
updates, opt_state = optimizer.update(grad, opt_state, eqx.filter(model, eqx.is_array))
model = eqx.apply_updates(model, updates)
assert model.buffer == jnp.ones(())  # Fails! The buffer's gradient is zero, but weight decay still shrinks it

Unless I am missing a downside, I think the recommended approach should be to wrap non-trainable nodes in a wrapper class (NonTrainable) and partition the parameters, e.g. with:

params, static = eqx.partition(
    model,
    eqx.is_inexact_array,
    is_leaf=lambda leaf: isinstance(leaf, NonTrainable),
)
patrick-kidger commented 2 weeks ago

Ah! That really isn't very good, you're right.

Hmm, I'm trying to figure out if there's a way to handle this ergonomically. The best I can come up with is to wrap the Optax calls (like we already do for eqx.apply_updates) with something that respects such a NonTrainable wrapper. This is just such an easy footgun!

dlwh commented 2 weeks ago

FWIW I've landed on the Optax wrapper approach. I have a trainable/non_trainable mask that I create early on and partition that way. I don't even bother with stop_gradient most of the time and pray that XLA does the dead-code elimination (DCE) for me (it seems to).

For things that are really constants (e.g. rotary embeddings) I just materialize those in the kernel with jax.ensure_compile_time_eval.
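As an illustration of the last point, here is a small sketch of materializing a constant inside a jitted function with jax.ensure_compile_time_eval. The function name and the sine table are toy assumptions standing in for a real rotary embedding, which the comment does not show.

```python
import jax
import jax.numpy as jnp


@jax.jit
def add_position_table(x):
    # Inside this context, operations on concrete values run eagerly at
    # trace time, so `table` enters the jitted program as a constant
    # rather than being stored as a model parameter.
    with jax.ensure_compile_time_eval():
        positions = jnp.arange(x.shape[0])  # the shape is static under jit
        table = jnp.sin(positions * 0.1)    # toy stand-in for an embedding table
    return x + table


out = add_position_table(jnp.zeros(4))
```

Since the table is never a leaf of the model, it cannot be picked up by the optimizer or by weight decay in the first place.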

patrick-kidger commented 2 weeks ago

Ah, nice! Okay, I think I'm convinced.

I'd be happy to take a PR implementing this, then.