My approach would be to define a filter spec for that set of parameters, then define the optimizer over just that set. Similar to https://docs.kidger.site/equinox/tricks/#custom-per-parameter-behaviour, I would first try something like:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax


class Weight(eqx.Module):
    w: jax.Array

    def get(self):
        return self.w

    def __matmul__(self, o):
        return self.w @ o


model = eqx.nn.MLP(2, "scalar", 10, 3, key=jax.random.PRNGKey(0))

is_linear = lambda x: isinstance(x, eqx.nn.Linear)
get_weights = lambda m: [
    x.weight
    for x in jax.tree_util.tree_leaves(m, is_leaf=is_linear)
    if is_linear(x)
]
# Wrap every linear layer's weight in the Weight marker class.
weighted_model = eqx.tree_at(get_weights, model, replace_fn=Weight)

weight_optimizer = optax.adamw(0.1)
other_optimizer = optax.sgd(0.1)

w_filter = lambda x: isinstance(x, Weight)
leaf = lambda x: isinstance(x, Weight)
o_filter = lambda x: not isinstance(x, Weight) and eqx.is_array(x)

w_opt_state = weight_optimizer.init(eqx.filter(weighted_model, w_filter, is_leaf=leaf))
o_opt_state = other_optimizer.init(eqx.filter(weighted_model, o_filter, is_leaf=leaf))

grads = eqx.filter_grad(lambda x, y: x(y))(weighted_model, jnp.ones(2))

# For adamw, the params argument must be filtered the same way as the
# gradients/optimizer state, since weight decay tree-maps over both.
updates1, _ = weight_optimizer.update(
    eqx.filter(grads, w_filter, is_leaf=leaf),
    w_opt_state,
    eqx.filter(weighted_model, w_filter, is_leaf=leaf),
)
updates2, _ = other_optimizer.update(
    eqx.filter(grads, o_filter, is_leaf=leaf), o_opt_state, weighted_model
)
weighted_model = eqx.apply_updates(weighted_model, updates1)
weighted_model = eqx.apply_updates(weighted_model, updates2)
Thanks for the detailed response @lockwo. A few comments though:
- My actual model is a Transformer with 32 TransformerLayers, each of which consists of MHA, Linear and LayerNorm layers. Ideally I should filter my transformer based on the layer type as well as the shapes of the parameters.
- optax.adamw comes with a mask argument where you can specify the group to which weight decay won't be applied (see the sketch just below this list). Link to the documentation for your reference.
- PS: I think there should be a cleaner way of filtering params and distributing them into different groups. I don't have a clear-cut answer to this right away, but it should be possible.
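(For concreteness, here is a minimal sketch of the mask keyword being referred to. The ndim-based rule and the params variable are placeholders for illustration only; params is assumed to be a pytree of arrays.)

import optax
import jax.tree_util as jtu

# Illustrative rule only: apply decay to anything with >= 2 dimensions,
# i.e. skip 1D parameters such as biases and LayerNorm scales/offsets.
decay_mask = jtu.tree_map(lambda p: p.ndim >= 2, params)

# Leaves that are True receive weight decay; leaves that are False do not.
optim = optax.adamw(learning_rate=3e-4, weight_decay=1e-2, mask=decay_mask)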
Your points are generally true I think; I was just highlighting the general framework I would use.
I don't know if I could write a universal thing that would definitely work for your codebase, but this approach is general. If you have a more complicated MVC, I could also take a look.
Gotcha! Thanks for the clarification. Also, I am not looking for a universal pattern but rather a cleaner pattern.
Re the more complicated MVC: yeah, let me do that, because that would give everyone more clarity on what I am trying to achieve. Give me a couple of hours.
@lockwo here is a better MVC:
class MLP(eqx.Module):
    fc1: eqx.nn.Linear
    fc2: eqx.nn.Linear

    def __init__(self, key, dtype=jnp.bfloat16):
        key1, key2 = jax.random.split(key, 2)
        self.fc1 = eqx.nn.Linear(32, 64, key=key1, dtype=dtype)
        self.fc2 = eqx.nn.Linear(64, 64, key=key2, dtype=dtype)

    def __call__(self, x):
        pass


class Attention(eqx.Module):
    wqkv: eqx.nn.Linear
    proj: eqx.nn.Linear
    drop: eqx.nn.Dropout

    def __init__(self, key, dtype=jnp.bfloat16):
        key1, key2 = jax.random.split(key, 2)
        self.wqkv = eqx.nn.Linear(64, 3 * 64, key=key1)  # 3 for qkv
        self.proj = eqx.nn.Linear(64, 64, key=key2)
        self.drop = eqx.nn.Dropout()

    def __call__(self, x, mask=None):
        pass


class TransformerBlock(eqx.Module):
    norm_1: eqx.nn.LayerNorm
    norm_2: eqx.nn.LayerNorm
    attn: Attention
    mlp: MLP

    def __init__(self, key, dtype=jnp.bfloat16):
        key1, key2 = jax.random.split(key, 2)
        self.norm_1 = eqx.nn.LayerNorm(64)
        self.attn = Attention(key=key1, dtype=dtype)
        self.norm_2 = eqx.nn.LayerNorm(64)
        self.mlp = MLP(key=key2, dtype=dtype)

    def __call__(self, x, mask=None):
        pass


class Transformer(eqx.Module):
    pos_embed: eqx.nn.Embedding
    tf_blocks: list  # list of TransformerBlock
    norm: eqx.nn.LayerNorm

    def __init__(self, key, num_layers=2, dtype=jnp.bfloat16):
        keys = jax.random.split(key, num_layers + 3)
        key1, key2, key3, tf_keys = keys[0], keys[1], keys[2], keys[3:]
        self.tf_blocks = [TransformerBlock(tf_keys[i]) for i in range(num_layers)]
        self.norm = eqx.nn.LayerNorm(64)
        self.pos_embed = eqx.nn.Embedding(64, 64, key=key1)

    def __call__(self, x):
        pass


model = Transformer(jax.random.PRNGKey(1))
Here, I would like to apply weight decay to all the weight params except for the normalization layers and the biases.
So in that case, it would be something like:
exclude = lambda x: isinstance(x, eqx.nn.LayerNorm)
leaf = lambda x: hasattr(x, "weight") and not exclude(x)
get_weights = lambda m: [x.weight for x in jax.tree_util.tree_leaves(m, is_leaf=leaf) if isinstance(x, eqx.Module)]
weighted_model = eqx.tree_at(get_weights, model, replace_fn=Weight)
and you don't need two optimizers, because now you can filter for Weight or use the mask.
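(A minimal sketch of that single-optimizer variant, reusing the Weight wrapper and weighted_model from above. It assumes Optax's support for passing a callable as the mask, which builds the mask from whatever pytree Optax hands it.)

import jax.tree_util as jtu

def decay_mask_fn(tree):
    # True for every array inside a Weight wrapper, False for every other leaf.
    is_weight = lambda x: isinstance(x, Weight)

    def per_leaf(x):
        if is_weight(x):
            return jtu.tree_map(lambda _: True, x)
        return False

    return jtu.tree_map(per_leaf, tree, is_leaf=is_weight)

params = eqx.filter(weighted_model, eqx.is_array)
optimizer = optax.adamw(1e-3, weight_decay=1e-2, mask=decay_mask_fn)
opt_state = optimizer.init(params)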
Perfect. Thanks a lot. Let me try this, will report my findings. BTW, did you mean leaf instead of Weight in the replace_fn, @lockwo?
Looks like I still can't get it working (either I am misunderstanding something here, or it is just too much work to achieve this). Here is what I want, but it doesn't seem to be working for me:
def set_weights(weight, bias):
    # Set both the weights and bias to None
    # if they are to be masked from weight decay
    return None, None


def is_layer(x):
    return isinstance(x, eqx.Module)


def get_weights(model):
    weights = []
    biases = []
    for x in jax.tree.leaves(model, is_leaf=is_layer):
        # For each layer, check the shape of the parameters.
        # If any param is 1D, store it and return it to replace it with None
        if hasattr(x, "weight") and x.weight.ndim < 2:
            weights.append(x.weight)
        if hasattr(x, "bias"):
            biases.append(x.bias)
    return weights, biases


masked_params = eqx.tree_at(get_weights, model, set_weights)
Would appreciate any help on this
Any suggestions? @patrick-kidger
I've run into this issue myself. Would it be possible to, say, extend eqx.field to add some sort of tagging system to the field metadata, which we could then use in different filters?
If I understand correctly, you want a boolean mask (to pass to optax.adamw(..., mask=...)), indicating which parameters you want to apply weight decay to?
From your description it sounds like you want to apply weight decay to (a) the weights and biases of all linear layers, and (b) just the biases of all LayerNorm layers? If so, then that would correspond to:
import jax.tree_util as jtu

params = ...


def is_layer(x):
    return isinstance(x, eqx.nn.Linear) or isinstance(x, eqx.nn.LayerNorm)


def set_mask(x):
    if isinstance(x, eqx.nn.Linear):
        return jtu.tree_map(lambda _: True, x)
    elif isinstance(x, eqx.nn.LayerNorm):
        mask = jtu.tree_map(lambda _: False, x)
        mask = eqx.tree_at(lambda m: m.bias, mask, True)
        return mask
    else:
        return jtu.tree_map(lambda _: False, x)


mask = jtu.tree_map(set_mask, params, is_leaf=is_layer)
Note that here, params is whatever you have asked Optax to optimize. Typically this is just the parameters of your model, as obtained by something like params = eqx.filter(model, eqx.is_array); see also this FAQ entry.
Also note how I'm not using hasattr, or checking ndim, or anything like that. Just isinstance checks! I'd recommend this approach as usually being much more reliable.
On the follow-up question about extending eqx.field: this is an idea that people like to suggest every now and again, but I'm afraid it doesn't work in general. (You can probably find some past discussions if you go looking back through this issue tracker.) There are several issues, the two most notable being:
1. field is something owned by the surrounding class, but when you tree-map you get the individual parameters.
2. eqx.Module is just a PyTree. Nothing more or less. Keeping this rule makes it easy to reason about Equinox code. Changing that to add special behaviour -- here some field metadata -- seems like a step in the wrong direction.
But that said, you are free to add your own metadata to eqx.field! Just like dataclasses.field. So if you want to construct something that works in your individual use-case then you are free to do so :)
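(A minimal sketch of that last point. The "decay" tag is a made-up metadata key, not anything Equinox or Optax knows about; eqx.field simply forwards metadata to dataclasses.field, so it can be read back off the class's fields.)

import dataclasses

import equinox as eqx
import jax
import jax.numpy as jnp


class TaggedLinear(eqx.Module):
    # "decay" is a user-defined tag stored in the field metadata.
    weight: jax.Array = eqx.field(metadata={"decay": True})
    bias: jax.Array = eqx.field(metadata={"decay": False})


def decay_field_names(module):
    # eqx.Module is a dataclass, so the metadata is available on its fields.
    return [f.name for f in dataclasses.fields(module) if f.metadata.get("decay", False)]


layer = TaggedLinear(weight=jnp.zeros((4, 4)), bias=jnp.zeros((4,)))
print(decay_field_names(layer))  # ['weight']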
Thanks @patrick-kidger for the detailed info. This is very helpful. One last query on this: in the above code, you have defined the set_mask function to set the mask for the weights and biases, but you haven't used it anywhere. Did you miss a line or two by any chance?
Typo'd. Fixed!
This still doesn't work. Here is a MWE that you can copy-paste and try:
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import equinox as eqx
import optax


class MLP(eqx.Module):
    fc1: eqx.nn.Linear
    fc2: eqx.nn.Linear

    def __init__(self, key, dtype=jnp.bfloat16):
        key1, key2 = jax.random.split(key, 2)
        self.fc1 = eqx.nn.Linear(32, 64, key=key1, dtype=dtype)
        self.fc2 = eqx.nn.Linear(64, 64, key=key2, dtype=dtype)

    def __call__(self, x):
        pass


class Attention(eqx.Module):
    wqkv: eqx.nn.Linear
    proj: eqx.nn.Linear
    drop: eqx.nn.Dropout

    def __init__(self, key, dtype=jnp.bfloat16):
        key1, key2 = jax.random.split(key, 2)
        self.wqkv = eqx.nn.Linear(64, 3 * 64, key=key1)  # 3 for qkv
        self.proj = eqx.nn.Linear(64, 64, key=key2)
        self.drop = eqx.nn.Dropout()

    def __call__(self, x, mask=None):
        pass


class TransformerBlock(eqx.Module):
    norm_1: eqx.nn.LayerNorm
    norm_2: eqx.nn.LayerNorm
    attn: Attention
    mlp: MLP

    def __init__(self, key, dtype=jnp.bfloat16):
        key1, key2 = jax.random.split(key, 2)
        self.norm_1 = eqx.nn.LayerNorm(64)
        self.attn = Attention(key=key1, dtype=dtype)
        self.norm_2 = eqx.nn.LayerNorm(64)
        self.mlp = MLP(key=key2, dtype=dtype)

    def __call__(self, x, mask=None):
        pass


class Transformer(eqx.Module):
    pos_embed: eqx.nn.Embedding
    tf_blocks: list  # list of TransformerBlock
    norm: eqx.nn.LayerNorm

    def __init__(self, key, num_layers=2, dtype=jnp.bfloat16):
        keys = jax.random.split(key, num_layers + 3)
        key1, key2, key3, tf_keys = keys[0], keys[1], keys[2], keys[3:]
        self.tf_blocks = [TransformerBlock(tf_keys[i]) for i in range(num_layers)]
        self.norm = eqx.nn.LayerNorm(64)
        self.pos_embed = eqx.nn.Embedding(64, 64, key=key1)

    def __call__(self, x, y, mask=None):
        pos_embed = jax.vmap(self.pos_embed)(y)


def is_layer(x):
    return isinstance(x, eqx.nn.Linear) or isinstance(x, eqx.nn.LayerNorm)


def set_mask(x):
    if isinstance(x, eqx.nn.Linear):
        return jtu.tree_map(lambda _: True, x)
    elif isinstance(x, eqx.nn.LayerNorm):
        mask = jtu.tree_map(lambda _: False, x)
        mask = eqx.tree_at(lambda m: m.bias, mask, True)
        return mask
    else:
        return jtu.tree_map(lambda _: False, x)


model = Transformer(jax.random.PRNGKey(1))
params = eqx.filter(model, eqx.is_array)
mask = jtu.tree_map(set_mask, params, is_leaf=is_layer)
optim = optax.adamw(learning_rate=1e-4, mask=mask)
opt_state = optim.init(params)  # raises the error below
---> 83 opt_state = optim.init(params)
File ~/miniconda3/envs/jaxenv/lib/python3.11/site-packages/optax/_src/combine.py:64, in chain.<locals>.init_fn(params)
63 def init_fn(params):
---> 64 return tuple(fn(params) for fn in init_fns)
File ~/miniconda3/envs/jaxenv/lib/python3.11/site-packages/optax/_src/combine.py:64, in <genexpr>(.0)
63 def init_fn(params):
---> 64 return tuple(fn(params) for fn in init_fns)
File ~/miniconda3/envs/jaxenv/lib/python3.11/site-packages/optax/_src/wrappers.py:544, in masked.<locals>.init_fn(params)
541 if isinstance(params, _state_utils._ParamsPlaceholder): # pylint:disable=protected-access
542 return MaskedState(inner_state=inner.init(params))
--> 544 mask_tree = mask(params) if callable(mask) else mask
545 masked_params = mask_pytree(params, mask_tree)
546 return MaskedState(inner_state=inner.init(masked_params))
TypeError: Transformer.__call__() missing 1 required positional argument: 'y'
Probably a known Optax oddity: they check to see if certain values are callable, and if they are, they call them. This takes precedence over the pytree-ness of an object. Wrap your masks/gradients/etc. into a length-1 list to defeat the callable check.
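(For concreteness, one literal reading of that suggestion, continuing from the MWE above; as noted further down in the thread, it did not end up resolving the issue in this particular case.)

# Wrap the top-level objects in a length-1 list so that what Optax sees at the
# top level is a plain list rather than a callable eqx.Module.
optim = optax.adamw(learning_rate=1e-4, mask=[mask])
opt_state = optim.init([params])
# Gradients would then also need the same wrapping at update time, e.g.
# updates, opt_state = optim.update([grads], opt_state, [params])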
Should I raise an issue on the Optax repo regarding this? This seems very limiting and unnecessary, IMHO.
Wrap your masks/gradients/etc. into a length-1 list to defeat the callable check.
This didn't work btw
@patrick-kidger IMO it would be good if you can take a look at the issue and the corresponding PR raised in Optax for this issue. They fixed it, but it broke the updates
workflow of adamw
w/o mask. I am suggesting this because there isn't going to be a new release of optax anytime soon, so it will be worth fixing them now
What are you proposing that Equinox does differently?
What are you proposing that Equinox does differently?
Sorry for not being very clear in my earlier comment. I am not suggesting any change on the Equinox side. I am asking if you can share some thoughts in the thread (issue) I opened on the Optax side.
Apologies if this has been asked before, but I couldn't find any example that demonstrates this in a simple manner.
I have a model built in Equinox. Now, I want to use the AdamW optimizer where I apply weight_decay only to certain parameters (i.e. divide the pytree into two groups and apply weight decay to one group). For example, what if I want to apply weight decay only to the weights (except for the normalization layers)?