google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

nnx.initializers.uniform should support custom lower and upper bounds #4127

Closed. SamPruden closed this 2 months ago

SamPruden commented 2 months ago

This is a tiny petty usability thing, but I just had to write this as an nnx.Conv argument:

kernel_init = (lambda key, shape, dtype: jax.random.uniform(
  key, shape, dtype,
  minval=-args.conv_init,
  maxval=args.conv_init
)) if args.conv_init else None,

That felt quite silly, because it seems like I should just be able to do:

kernel_init = nnx.initializers.uniform(-args.conv_init, args.conv_init) if args.conv_init else None,

However, initializers.uniform only takes a single scale parameter and, for some reason, outputs values in [0, scale). I would say that, where applicable, the initializers should act as wrappers around the equivalent jax.random functions and offer the same options.
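
For illustration, here is a minimal sketch of the kind of wrapper I mean (a hypothetical uniform_between factory, not an existing Flax or JAX API):

import jax
import jax.numpy as jnp

# Hypothetical factory: returns an initializer with the usual
# (key, shape, dtype) signature, sampling uniformly in [minval, maxval).
def uniform_between(minval, maxval):
  def init(key, shape, dtype=jnp.float32):
    return jax.random.uniform(key, shape, dtype, minval=minval, maxval=maxval)
  return init

# e.g. kernel_init=uniform_between(-args.conv_init, args.conv_init)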

I would say the same for initializers.normal; however, I've just noticed that random.normal doesn't let you choose the mean or stddev. That's quite surprising. I would expect those options to be available in both places.
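
For what it's worth, shifting and scaling the standard normal gets you there today; a quick sketch (hypothetical normal_init helper, not an existing API):

import jax
import jax.numpy as jnp

# Hypothetical helper: an N(mean, stddev**2) initializer built on top of the
# standard normal that jax.random.normal provides.
def normal_init(mean=0.0, stddev=1.0):
  def init(key, shape, dtype=jnp.float32):
    return mean + stddev * jax.random.normal(key, shape, dtype)
  return init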

SamPruden commented 2 months ago

I'm being a little unfair by not golfing the manual case, so I should point out that it can be written a bit more compactly:

kernel_init = (lambda *a: random.uniform(*a, -args.conv_init, args.conv_init)) if args.conv_init else None,

It's not too bad but it's not quite as clean as it could be.

Whilst I'm on the topic of tiny petty unimportant things about initializers and randomness, it would also be nice if random.uniform could accept an int as a shape instead of having to write (n,).
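
Again, a one-line shim covers it; just to illustrate the ergonomics I'm after (hypothetical wrapper, not an existing API):

import jax
import jax.numpy as jnp

# Hypothetical shim: accept either an int or a tuple for shape.
def uniform_any_shape(key, shape, dtype=jnp.float32, minval=0.0, maxval=1.0):
  if isinstance(shape, int):
    shape = (shape,)
  return jax.random.uniform(key, shape, dtype, minval=minval, maxval=maxval)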

cgarciae commented 2 months ago

Hey, we mostly just re-export jax's initializers from jax.nn.initializers for convenience.
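
For example, you can check that directly (assuming the re-export is a plain alias):

import jax.nn.initializers
from flax import nnx

# If nnx.initializers.uniform is just re-exported from JAX, this prints True,
# so the feature request really belongs in the JAX repo.
print(nnx.initializers.uniform is jax.nn.initializers.uniform)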

SamPruden commented 2 months ago

> Hey, we mostly just re-export jax's initializers from jax.nn.initializers for convenience.

Ah sorry, of course! I'll refile there if I can be bothered to annoy them with something so trivial.