patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Adding a kernel constraint #809

Closed · the-wind-is-rising closed this issue 3 months ago

the-wind-is-rising commented 3 months ago

Hi all,

I've been using Equinox a lot recently and am very much enjoying it. Thank you for your hard work.

I'm currently translating a model from TF2/Keras and want to make sure I'm properly re-implementing the max_norm constraint for certain layers.

The TensorFlow/Keras implementation looks like this:

# keras.constraints.MaxNorm.__call__
def __call__(self, w):
    w = backend.convert_to_tensor(w)
    norms = ops.sqrt(ops.sum(ops.square(w), axis=self.axis, keepdims=True))
    desired = ops.clip(norms, 0, self.max_value)
    return w * (desired / (backend.epsilon() + norms))

Rather than reimplementing the entire constraint API, I just want to attach max_norm to one convolutional layer. Here is what I have:

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array

class ConstrainedMaxNormConv2d(eqx.nn.Conv2d):
    max_norm: float

    def __init__(self, *args, max_norm: float = 1.0, **kwargs):
        self.max_norm = max_norm
        super().__init__(*args, **kwargs)

    def __call__(self, x: Array, *, eps: float = 1e-7):
        norms = jnp.sqrt(jnp.sum(jnp.square(self.weight), keepdims=True))
        desired = jnp.clip(norms, 0, self.max_norm)
        self.weight.at[:].set(self.weight * (desired / (eps + norms)))
        return super().__call__(x)

My question: is this the best way to implement a kernel constraint in equinox? Initial tests are promising, but I'd love some feedback from more experienced developers. Thanks.

lockwo commented 3 months ago

My initial reaction would be to make a layer wrapper (sort of like eqx.nn.SpectralNorm) that explicitly defines which attribute you are operating on, and then redefines that attribute in the call. Something along the lines of:

import jax
import jax.numpy as jnp
import equinox as eqx

class MaxNormConstraint(eqx.Module):
    layer: eqx.Module
    weight_name: str = eqx.field(static=True)
    max_norm: float = eqx.field(static=True)

    def __init__(
        self,
        layer,
        weight_name,
        max_norm,
    ):
        self.layer = layer
        self.weight_name = weight_name
        self.max_norm = max_norm

    @jax.named_scope("eqx.nn.MaxNormConstraint")
    def __call__(
        self,
        x,
        *,
        key=None,
        inference=None,
    ):
        eps = 1e-8
        weight = getattr(self.layer, self.weight_name)
        # Compute the weight's norm, clip it to at most max_norm, and
        # rescale the weight accordingly (all out of place).
        norms = jnp.sqrt(jnp.sum(jnp.square(weight), keepdims=True))
        desired = jnp.clip(norms, 0, self.max_norm)
        new_weight = weight * (desired / (eps + norms))
        # tree_at builds a copy of the layer with the rescaled weight;
        # the original layer is left untouched.
        layer = eqx.tree_at(
            lambda l: getattr(l, self.weight_name), self.layer, new_weight
        )
        return layer(x)

linear = eqx.nn.Linear(10, 20, key=jax.random.key(0))
constrained_linear = MaxNormConstraint(linear, "weight", 1.0)
constrained_linear(jnp.arange(10.0))
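
The same wrapper should work unchanged on the Conv2d layer from the question; for example (an untested sketch reusing the MaxNormConstraint class above):

conv = eqx.nn.Conv2d(1, 8, kernel_size=3, key=jax.random.key(1))
constrained_conv = MaxNormConstraint(conv, "weight", 1.0)
constrained_conv(jnp.ones((1, 28, 28)))  # input is (channels, height, width)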

PS: self.weight.at[:].set(self.weight * (desired / (eps + norms))) is an out-of-place operation; it returns a new array without modifying self.weight, so that line has no effect.
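
To illustrate the point with a minimal standalone example: JAX arrays are immutable, so .at[...].set(...) always builds and returns a new array, which you need to capture and use.

import jax.numpy as jnp

w = jnp.zeros(3)
w.at[:].set(1.0)          # result discarded: w is unchanged
w_new = w.at[:].set(1.0)  # the updated array must be captured instead
# w is still [0. 0. 0.]; w_new is [1. 1. 1.]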

the-wind-is-rising commented 3 months ago

This worked great, thanks!