My initial reaction would be to make a layer wrapper (sort of like `eqx.nn.SpectralNorm`) that explicitly defines which attribute you are operating on, and then just redefines it in the call. Something along the lines of:
```python
import jax
import jax.numpy as jnp
import equinox as eqx


class MaxNormConstraint(eqx.Module):
    layer: eqx.Module
    weight_name: str = eqx.field(static=True)
    max_norm: float = eqx.field(static=True)

    def __init__(self, layer, weight_name, max_norm):
        self.layer = layer
        self.weight_name = weight_name
        self.max_norm = max_norm

    @jax.named_scope("eqx.nn.MaxNormConstraint")
    def __call__(self, x, *, key=None, inference=None):
        eps = 1e-8
        weight = getattr(self.layer, self.weight_name)
        # Global norm over the whole weight array. (Keras' MaxNorm defaults
        # to per-axis norms with axis=0; pass an axis to jnp.sum for that.)
        norms = jnp.sqrt(jnp.sum(jnp.square(weight), keepdims=True))
        desired = jnp.clip(norms, 0, self.max_norm)
        new_weight = weight * (desired / (eps + norms))
        # Out-of-place swap of the rescaled weight into the layer pytree.
        layer = eqx.tree_at(
            lambda l: getattr(l, self.weight_name), self.layer, new_weight
        )
        return layer(x)


linear = eqx.nn.Linear(10, 20, key=jax.random.key(0))
constrained_linear = MaxNormConstraint(linear, "weight", 1.0)
constrained_linear(jnp.arange(10.0))
```
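The same wrapper should cover the convolutional case from the original question, since `eqx.nn.Conv2d` also stores its kernel in a `weight` attribute. A quick (untested) sketch:

```python
# Hypothetical usage with a conv layer; shapes are just for illustration.
conv = eqx.nn.Conv2d(1, 8, kernel_size=3, key=jax.random.key(1))
constrained_conv = MaxNormConstraint(conv, "weight", 1.0)
constrained_conv(jnp.ones((1, 28, 28)))  # Conv2d expects (channels, h, w)
```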
PS: `self.weight.at[:].set(self.weight * (desired / (eps + norms)))` is an out-of-place operation, so on its own it has no effect there: it returns a new array rather than mutating `self.weight`.
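To spell that out: `.at[...].set(...)` follows JAX's functional-update semantics, so the result has to be rebound for the update to take effect. A minimal illustration:

```python
w = jnp.ones(3)
w.at[:].set(0.0)      # returns a new array; `w` itself is unchanged
print(w)              # [1. 1. 1.]
w = w.at[:].set(0.0)  # rebind the result for the update to take effect
print(w)              # [0. 0. 0.]
```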
This worked great, thanks
Hi all,
Have been using equinox a lot recently and am very much enjoying it. Thank you for your hard work.
I'm currently translating a model from tf2/keras, and wanted to make sure I was properly re-implementing the `max_norm` constraint for certain layers. The tensorflow implementation looks like this:
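(Paraphrasing from memory of the Keras source, it is roughly:)

```python
import tensorflow as tf

class MaxNorm(tf.keras.constraints.Constraint):
    def __init__(self, max_value=2, axis=0):
        self.max_value = max_value
        self.axis = axis

    def __call__(self, w):
        # Norms along `axis`, clipped to max_value, then rescale the kernel.
        norms = tf.sqrt(tf.reduce_sum(tf.square(w), axis=self.axis, keepdims=True))
        desired = tf.clip_by_value(norms, 0, self.max_value)
        return w * (desired / (tf.keras.backend.epsilon() + norms))
```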
Rather than re-implementing the entire constraint API, I just want to attach `max_norm` to one convolutional layer. Here is what I have:

My question: is this the best way to implement a kernel constraint in equinox? Initial tests are promising, but I'd love some feedback from more experienced developers. Thanks.