Closed · SamPruden closed this issue 2 months ago
I'm being a little unfair by not golfing the manual case, so I should point out that it can be written a bit more compactly:
```python
kernel_init=(lambda *a: random.uniform(*a, -args.conv_init, args.conv_init)) if args.conv_init else None,
```
It's not too bad but it's not quite as clean as it could be.
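For context, the `*a` trick works because the initializer gets called as `kernel_init(key, shape, dtype)`, which lines up with the first three positional arguments of `jax.random.uniform`. A rough sketch of the un-golfed version (the layer sizes here are made up):

```python
import jax
import jax.numpy as jnp
from flax import nnx

# A plain symmetric-uniform initializer: sample weights in [-scale, scale).
def make_uniform_init(scale):
    def init(key, shape, dtype=jnp.float32):
        return jax.random.uniform(key, shape, dtype, -scale, scale)
    return init

conv = nnx.Conv(
    in_features=3,
    out_features=16,
    kernel_size=(3, 3),
    kernel_init=make_uniform_init(0.05),
    rngs=nnx.Rngs(0),
)
```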
Whilst I'm on the topic of tiny petty unimportant things about initializers and randomness, it would also be nice if `random.uniform` could accept an int as a shape instead of having to do `(n,)`.
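To spell out the shape point, assuming the behaviour described above:

```python
import jax

key = jax.random.PRNGKey(0)

# The shape currently has to be a sequence, even for a 1-D draw:
x = jax.random.uniform(key, (128,))

# The ask is for a bare int to work too, i.e. jax.random.uniform(key, 128).
```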
Hey, we mostly just re-export jax's initializers from jax.nn.initializers for convenience.
Ah sorry, of course! I'll refile there if I can be bothered to annoy them with something so trivial.
This is a tiny petty usability thing, but I just had to hand-write a uniform initializer as an `nnx.Conv` argument, which felt quite silly because it felt like I should be able to do it with `initializers.uniform`. However, `initializers.uniform` only takes a single `scale` parameter and outputs in `[0, scale)` for some reason. I would say that where applicable the initializers should act like wrappers around the equivalent `jax.random` functions and offer the same options.
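For concreteness, here's a rough sketch of the kind of wrapper I mean; `uniform_init` below is a hypothetical helper, not an existing `jax.nn.initializers` API:

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers

key = jax.random.PRNGKey(0)

# What exists today: initializers.uniform(scale) samples in [0, scale).
w0 = initializers.uniform(scale=0.1)(key, (3, 3), jnp.float32)

# Hypothetical wrapper exposing jax.random.uniform's minval/maxval options.
def uniform_init(minval=0.0, maxval=1.0):
    def init(key, shape, dtype=jnp.float32):
        return jax.random.uniform(key, shape, dtype, minval, maxval)
    return init

w1 = uniform_init(-0.05, 0.05)(key, (3, 3))
```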
I would say the same for `initializers.normal`, however I've just noticed that `random.normal` doesn't let you choose the mean or stddev. That's quite surprising. I would expect those options to be available in both places.
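And likewise a sketch for the normal case; since `jax.random.normal` only draws standard normals, the mean/stddev version just shifts and scales by hand (`normal_init` is hypothetical):

```python
import jax
import jax.numpy as jnp

# Hypothetical normal initializer with the mean/stddev options mentioned above.
def normal_init(mean=0.0, stddev=1.0):
    def init(key, shape, dtype=jnp.float32):
        # jax.random.normal draws from N(0, 1), so shift and scale the samples.
        return mean + stddev * jax.random.normal(key, shape, dtype)
    return init

w = normal_init(mean=0.0, stddev=0.02)(jax.random.PRNGKey(0), (3, 3))
```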