patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Why uniform bias initializer? #179

Open jenkspt opened 2 years ago

jenkspt commented 2 years ago

I noticed that the bias parameters in the linear and conv modules use a uniform initializer. Is there a good justification for this? I noticed that PyTorch does this too. I was expecting a zero initializer, which is also what Flax uses.

patrick-kidger commented 2 years ago

The justification is really just "that is what PyTorch does". (I copied the initialisers over when making Equinox.)

If you want a zero bias then this can be done by replacing the leaf:

linear = eqx.nn.Linear(...)
linear = eqx.tree_at(lambda l: l.bias, linear, jnp.zeros_like(linear.bias))

angusturner commented 2 years ago

Maybe this deserves a separate issue / feature request, but it was non-obvious to me as a first-time user (coming from Haiku) how to override the default initialization scheme in nn.Linear.

I think it would be cool if there was a w_init argument (similar to Haiku), or perhaps just a comment in the docs explaining the most idiomatic way to do this.

For what it's worth, I find the eqx.tree_at approach to be a bit "too clever". It's not immediately obvious to a new user how it works. Maybe that's just me though.

As an aside, I am really enjoying Equinox so far and appreciate all the work you have put into it!

patrick-kidger commented 2 years ago

Thanks, I'm glad you like it!

Regarding adding this to the docs: yep, this is on the horizon (#185).

Regarding making this an additional argument: so far I've been resisting this as it adds quite a lot of complexity. It would add a lot of arguments that I think most folks don't use (PyTorch doesn't have this). It would necessitate adding a new eqx.nn.init namespace to hold the existing initialisers. Both of the above clutter the documentation.

But that isn't a strong feeling on my part. This request does come up every now and again, and I'd be happy to be persuaded otherwise?

(cc @typedfemale as we were talking about this recently)

jenkspt commented 2 years ago

Is there a good way to change all bias initializers to zero?

angusturner commented 2 years ago

That is a fair point!

It is true that PyTorch doesn't take a w_init, but because modules are mutable in PyTorch it's very easy to just set linear.bias = ... (or do an in-place modification or whatever).

But I guess this is just teething pains of adjusting to immutability / PyTrees.

I will give this more thought as I adjust to the eqx.tree_at model surgery approach!

paganpasta commented 2 years ago

"Is there a good way to change all bias initializers to zero?"

+1 for this. Right now I have a hacky method relying on recursive getattr to first get a sequence of target attribute strings and then pass them to eqx.tree_at for replacement. Would like to know of a cleaner method to achieve this.

patrick-kidger commented 2 years ago

So I can think of two non-hacky methods for setting all biases to zero.

Option 1:

def linear(*args, **kwargs):
    out = eqx.nn.Linear(*args, **kwargs)
    out = eqx.tree_at(lambda l: l.bias, out, replace_fn=jnp.zeros_like)
    return out

and then using linear (or an analogous conv2d etc.) everywhere you were previously using eqx.nn.Linear.

Option 2:

model = ...
has_bias = lambda x: hasattr(x, "bias")
where = lambda m: [x.bias for x in jax.tree_util.tree_leaves(m, is_leaf=has_bias) if has_bias(x)]
model = eqx.tree_at(where, model, replace_fn=jnp.zeros_like)

FWIW I think I am hearing consensus in favour of this proposal! So perhaps let's just add it :D

The idea would be to:

This is a fair amount of work so this probably isn't something I'll find time to implement in the near future. If anyone involved in this thread feels suitably motivated, then I'd be very happy to accept a PR on this.

jenkspt commented 2 years ago

I'm not pushing to add init args. I agree that it adds unnecessary complexity (and work :sweat_smile:)

paganpasta commented 1 year ago

Option 2 covers my use-cases with a minor modification: the filter becomes if has_bias(x) and x.bias is not None. Thanks!

jloganolson commented 1 year ago

Does either option work with nn.MLP or would I need to do more surgery? It's unclear to me how I would use tree_at to access each linear layer but maybe my pytree skills are just weak...

patrick-kidger commented 1 year ago

Approach 1: the above "Option 2" should work out-of-the-box: it will detect each linear layer as these all have a bias attribute.

Approach 2: if you wanted, you could also switch out has_bias with is_linear = lambda x: isinstance(x, eqx.nn.Linear), and then do exactly as before. As eqx.nn.MLP uses linear layers internally, and all its linear layers are the ones with biases, then this will work equally well.

Approach 3: if you're working with an MLP specifically then you could also just list out all its linear layers manually:

mlp = eqx.nn.MLP(...)
where = lambda m: [lin.bias for lin in m.layers]
mlp = eqx.tree_at(where, mlp, replace_fn=jnp.zeros_like)

Any one of these approaches is equally fine.

Ultimately I think this kind of model surgery is one of the greatest strengths of Equinox. It takes a little getting used to, but variations on the above pattern allow you to perform almost any kind of adjustment to your model.

adam-coogan commented 1 year ago

I've been trying to experiment with different initialization schemes for MLPs and came across this issue. Is there a simple way to use tree_at to apply e.g. Lecun initialization to all weights and biases?

patrick-kidger commented 1 year ago

Something like:

import math

import equinox as eqx
import jax.random as jr
import jax.tree_util as jtu
from jaxtyping import Array, Float

key = jr.PRNGKey(...)
model = ... # your model

def lecun_init(weight: Float[Array, "out in"], key: jr.PRNGKey) -> Float[Array, "out in"]:
  out, in_ = weight.shape
  stddev = math.sqrt(1 / in_)
  # truncated_normal needs an explicit shape, otherwise it returns a scalar
  return stddev * jr.truncated_normal(key, lower=-2, upper=2, shape=weight.shape)

is_linear = lambda x: isinstance(x, eqx.nn.Linear)
get_weights = lambda m: [x.weight for x in jtu.tree_leaves(m, is_leaf=is_linear) if is_linear(x)]
weights = get_weights(model)
new_weights = [lecun_init(weight, subkey) for weight, subkey in zip(weights, jr.split(key, len(weights)))]
new_model = eqx.tree_at(get_weights, model, new_weights)

(The jaxtyping annotations are just a nice-to-have, they don't affect runtime.)

vspinu commented 1 year ago

"Add extra arguments weight_init=..., bias_init=... to eqx.nn.{Linear, Conv, MultiheadAttention, ...}."

May I suggest having a single init dict argument instead? Initializers could then be passed to nested sublayers through a dict that mirrors the sublayer hierarchy.

cloudhan commented 2 months ago

This should have been in #622 but I put it here as this issue is still open.

In the JAX world everything is a transformation, and Equinox already follows the model = some_map(model) pattern. Layer initialisation should then be a transformation too: layer = eqx.init(layer, key, weight=default_init, bias=default_init)

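import functools

import equinox as eqx
import jax.random as jrandom

def init(pytree, key, **kwargs):
  for attr, init_method in kwargs.items():
    key, init_key = jrandom.split(key)
    pytree = eqx.tree_at(
        lambda t: getattr(t, attr, None),
        pytree,
        replace_fn=functools.partial(init_method, key=init_key),
        is_leaf=lambda x: x is None,  # lets tree_at replace an attribute that is None
    )
  return pytree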

class MyModule(nn.Module):
  proj: nn.Linear

  def __init__(self, *, key=None):
    self.proj = eqx.init(nn.Linear(in_feat, out_feat), key, weight=xavier_init, bias=optional_default_init)  # optional_default_init accepts a None

How to initialise large models (which might rely on distributed RNG behaviour) remains an open question, though.