Very interesting points and thanks for pointing out the differences between PyTorch and Stax. I've never checked those! I think of Stax as an experimental extension of JAX, so I'd consider checking Flax instead.
Flax (google/flax) is the library I've seen recommended for all things neural networks with JAX.
I did some quick digging and, by default, Flax appears to use LeCun normal (jax.nn.initializers.lecun_normal()) for weights and zeros for biases: https://flax.readthedocs.io/en/latest/_modules/flax/nn/linear.html#Dense
```python
default_kernel_init = initializers.lecun_normal()
...
class Dense(base.Module):
  ...
  def apply(self, ..., kernel_init=default_kernel_init, bias_init=initializers.zeros):
    ...

class Conv(base.Module):
  ...
  def apply(self, ..., kernel_init=default_kernel_init, bias_init=initializers.zeros):
    ...
```
And the docs also show that with Flax you can set the initializer via the kernel_init (weights) and bias_init (bias) parameters to something other than the default value (at least in theory):
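For example, something like this minimal, untested sketch against the legacy flax.nn API excerpted above (the feature count is an arbitrary example value):

```python
import jax
import flax.nn as nn  # the legacy flax.nn module system from the docs excerpt above

class Model(nn.Module):
    def apply(self, x):
        # override the defaults: Glorot-uniform weights instead of LeCun normal
        return nn.Dense(x, features=128,
                        kernel_init=jax.nn.initializers.glorot_uniform(),
                        bias_init=jax.nn.initializers.zeros)
```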
So, that probably means that you'd be able to choose any of the following initializers from the core JAX library's jax.nn.initializers module (https://jax.readthedocs.io/en/latest/jax.nn.initializers.html); a quick usage sketch follows the list:
- jax.nn.initializers.zeros
- jax.nn.initializers.ones
- jax.nn.initializers.uniform()
- jax.nn.initializers.normal()
- jax.nn.initializers.variance_scaling()
- jax.nn.initializers.glorot_uniform()
- jax.nn.initializers.glorot_normal()
- jax.nn.initializers.lecun_uniform()
- jax.nn.initializers.lecun_normal()
- jax.nn.initializers.he_uniform()
- jax.nn.initializers.he_normal()

(Note that zeros and ones are initializers themselves rather than factories, so they're passed without calling them, as in the bias_init=initializers.zeros default above.)
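The factory-style ones all return an initializer with the same init(key, shape) interface, so you can also sample a weight matrix directly (a quick sketch; the shape is an arbitrary example):

```python
import jax

# glorot_uniform() returns an initializer you call as init(key, shape)
init = jax.nn.initializers.glorot_uniform()
W = init(jax.random.PRNGKey(0), (64, 32))  # fan_in=64, fan_out=32 weight matrix
```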
So many to choose from 👍 (is Xavier available though? https://www.deeplearning.ai/ai-notes/initialization/)
cc @dynamicwebpaige (DM UX)
Hi @8bitmp3, thanks for your feedback!
Xavier is available here (under the glorot_* names, since Glorot and Xavier initialization are the same thing): https://github.com/google/jax/blob/master/jax/_src/nn/initializers.py#L72-L77
For the bias initialization, the API is not perfect. Please take a look at the Colab: https://colab.research.google.com/drive/1I6X8_laXjtBq9y5Bri16Jje3npCg_ai6?usp=sharing
No worries @JiahaoYao
> Xavier is available here: https://github.com/google/jax/blob/master/jax/_src/nn/initializers.py#L72-L77
Ah yes, thanks - I forgot about the different names.
> For the bias initialization, the API is not perfect. Please take a look at the Colab:
I think you're still using jax.experimental.stax. I'd recommend giving Flax a go: https://github.com/google/flax (docs: https://flax.readthedocs.io/en/latest/).
Update: try flax.linen instead of flax.nn, e.g. https://github.com/google/flax/blob/master/flax/linen/linear.py
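With flax.linen, the kernel_init/bias_init overrides look roughly like this (a minimal sketch; the feature count and input shape are arbitrary example values):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Model(nn.Module):
    @nn.compact
    def __call__(self, x):
        # He-normal weights and zero biases instead of the LeCun-normal default
        return nn.Dense(features=128,
                        kernel_init=jax.nn.initializers.he_normal(),
                        bias_init=jax.nn.initializers.zeros)(x)

params = Model().init(jax.random.PRNGKey(0), jnp.ones((1, 32)))
```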
Cool, will do!
@JiahaoYao if flax isn't enough for you, check out the haiku library by DeepMind (also built on JAX). For example, you can initialize params with a truncated normal via hk.initializers.TruncatedNormal() (example) to help mitigate the dead ☠️ neuron issue.
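A minimal Haiku sketch (the feature count, stddev, and input shape are arbitrary example values):

```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    # linear layer with truncated-normal weight initialization
    return hk.Linear(128, w_init=hk.initializers.TruncatedNormal(stddev=0.01))(x)

net = hk.transform(forward)
params = net.init(jax.random.PRNGKey(0), jnp.ones((1, 32)))
```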
Thanks so much, @8bitmp3! I will close this issue and might reopen it if I have further questions. Thank you again!
Hi, the default initialization is different between the two frameworks. JAX's stax uses Glorot normal for weights and normal for biases: https://github.com/google/jax/blob/master/jax/experimental/stax.py#L47. PyTorch uses kaiming_uniform for weights and LeCun uniform for biases: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L87-L91
Reusing these initializers for the bias fails, because the bias term has shape (d_out,) and https://github.com/google/jax/blob/master/jax/_src/nn/initializers.py#L50-L55 raises an error when it tries to compute the fan-in.
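One workaround is to pass the fan-in to the bias initializer explicitly. A hypothetical sketch against jax.experimental.stax (pytorch_style_bias_init is my own helper, and he_uniform only approximates PyTorch's kaiming_uniform(a=sqrt(5)) default; the sizes are arbitrary):

```python
import jax.numpy as jnp
from jax import random
from jax.experimental import stax  # module path as linked above
from jax.nn.initializers import he_uniform

def pytorch_style_bias_init(fan_in):
    # the bias shape (d_out,) carries no fan-in information, so take it as an
    # argument and mirror PyTorch's uniform(-1/sqrt(fan_in), 1/sqrt(fan_in))
    bound = 1.0 / jnp.sqrt(fan_in)
    def init(key, shape, dtype=jnp.float32):
        return random.uniform(key, shape, dtype, minval=-bound, maxval=bound)
    return init

init_fn, apply_fn = stax.Dense(128, W_init=he_uniform(),
                               b_init=pytorch_style_bias_init(64))
```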
Also, the conv layers are initialized the same way as the dense layers in both libraries: https://github.com/google/jax/blob/master/jax/experimental/stax.py#L61-L67 and https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py#L110-L115
One more thing: for the Kaiming initialization, PyTorch computes a nonlinearity-dependent gain, https://github.com/pytorch/pytorch/blob/780f854135c9aa65ffa834ea6c37b914ec774cf4/torch/nn/init.py#L380-L381, which is not present in the current JAX.
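You can fold the gain in by hand, though. A sketch using jax.nn.initializers.variance_scaling (the gain enters squared, since scale sets the variance; the slope value is an arbitrary example):

```python
from jax.nn.initializers import variance_scaling

# for leaky_relu with negative slope a, PyTorch's gain is sqrt(2 / (1 + a**2)),
# so fold gain**2 into variance_scaling's scale argument
a = 0.01
kaiming_uniform_with_gain = variance_scaling(
    scale=2.0 / (1.0 + a ** 2), mode="fan_in", distribution="uniform")
```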