google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

The neural network initialization of JAX is different from PyTorch #4862

Closed: JiahaoYao closed this issue 3 years ago

JiahaoYao commented 3 years ago

Hi, the default initialization differs between the two libraries. In JAX, the stax default (https://github.com/google/jax/blob/master/jax/experimental/stax.py#L47) is Glorot normal for the weights and a normal distribution for the bias, whereas PyTorch uses kaiming_uniform for the weights and a uniform distribution scaled by the fan-in (LeCun-style) for the bias (https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L87-L91).
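To make the difference concrete, here is a minimal sketch (assuming a recent JAX release; PyTorch itself is not imported) that draws a 256 by 256 weight matrix both ways and compares empirical standard deviations. Glorot normal targets sqrt(2 / (fan_in + fan_out)). PyTorch's nn.Linear weight init, kaiming_uniform_ with a=sqrt(5), reduces to U(-1/sqrt(fan_in), 1/sqrt(fan_in)), whose standard deviation is 1/sqrt(3 * fan_in); that distribution is emulated here with jax.random.uniform.

```python
import math

import jax
import jax.numpy as jnp

fan_in, fan_out = 256, 256
key_glorot, key_torch = jax.random.split(jax.random.PRNGKey(0))

# stax's Dense default for the weights: Glorot (Xavier) normal.
w_glorot = jax.nn.initializers.glorot_normal()(key_glorot, (fan_in, fan_out))

# PyTorch nn.Linear default, emulated in JAX: kaiming_uniform_(a=sqrt(5))
# works out to U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
bound = 1.0 / math.sqrt(fan_in)
w_torch = jax.random.uniform(key_torch, (fan_in, fan_out), minval=-bound, maxval=bound)

expected_glorot_std = math.sqrt(2.0 / (fan_in + fan_out))  # 0.0625
expected_torch_std = 1.0 / math.sqrt(3.0 * fan_in)         # about 0.0361

print(f"glorot_normal: {float(jnp.std(w_glorot)):.4f} (target {expected_glorot_std:.4f})")
print(f"torch-style:   {float(jnp.std(w_torch)):.4f} (target {expected_torch_std:.4f})")
```

When fan_in equals fan_out, the PyTorch-style spread is exactly 1/sqrt(3) of the Glorot one, so the two defaults start training from noticeably different weight scales.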

To match PyTorch, I tried:

Dense(n_actions, W_init=jax.nn.initializers.kaiming_uniform(), b_init=jax.nn.initializers.lecun_uniform())

This fails for the bias initialization: the bias has shape (d_out,), and the fan-in computation in https://github.com/google/jax/blob/master/jax/_src/nn/initializers.py#L50-L55 raises an error because it cannot find a fan-in dimension in that shape.
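A minimal sketch of both the failure and one workaround, assuming a recent JAX release. The fan-based initializers read the fan-in from the second-to-last axis (in_axis=-2 by default), which a 1-D bias shape does not have; recent versions raise ValueError and older ones may surface an IndexError, so both are caught. The helper fan_in_uniform_bias is hypothetical, not part of jax.nn.initializers: it takes the fan-in explicitly and samples PyTorch's bias range U(-1/sqrt(fan_in), 1/sqrt(fan_in)).

```python
import math

import jax
import jax.numpy as jnp

d_in, d_out = 8, 4
key = jax.random.PRNGKey(0)

# A fan-based initializer on a 1-D bias shape fails: there is no input
# axis (in_axis=-2 by default) from which to read the fan-in.
try:
    jax.nn.initializers.lecun_uniform()(key, (d_out,))
    failed = False
except (ValueError, IndexError):
    failed = True
print("lecun_uniform on a 1-D shape failed:", failed)


def fan_in_uniform_bias(fan_in):
    """Hypothetical helper: a bias initializer told the fan-in up front.

    Samples U(-1/sqrt(fan_in), 1/sqrt(fan_in)), the range PyTorch's
    nn.Linear uses for its bias.
    """
    bound = 1.0 / math.sqrt(fan_in)

    def init(key, shape, dtype=jnp.float32):
        return jax.random.uniform(key, shape, dtype, minval=-bound, maxval=bound)

    return init


b = fan_in_uniform_bias(d_in)(key, (d_out,))
print(b.shape)
```

Because stax's Dense passes b_init only the bias shape, the fan-in has to be bound when the layer is constructed, e.g. b_init=fan_in_uniform_bias(d_in).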

Also, in both libraries the conv layers use the same initialization scheme as the dense layers: https://github.com/google/jax/blob/master/jax/experimental/stax.py#L61-L67 (stax) and https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py#L110-L115 (PyTorch).
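For a convolution the only new ingredient is the fan-in: the kernel's spatial size multiplies the number of input channels. A sketch, assuming an HWIO kernel layout (height, width, in, out) chosen for illustration, that reproduces PyTorch's default conv range U(-1/sqrt(fan_in), 1/sqrt(fan_in)) with jax.nn.initializers.variance_scaling. Scale 1/3 is used because variance_scaling's uniform bound is sqrt(3 * scale / fan_in), which equals 1/sqrt(fan_in) when scale = 1/3.

```python
import math

import jax
import jax.numpy as jnp

# Assumed kernel layout: HWIO (height, width, in_channels, out_channels),
# so the input-channel axis is -2 and the output-channel axis is -1.
kh, kw, c_in, c_out = 3, 3, 16, 32
fan_in = kh * kw * c_in  # receptive field size times input channels = 144

# variance_scaling folds the non-channel (spatial) axes into the fan-in.
torch_style_kernel_init = jax.nn.initializers.variance_scaling(
    1.0 / 3.0, "fan_in", "uniform", in_axis=-2, out_axis=-1
)

key_w, key_b = jax.random.split(jax.random.PRNGKey(0))
kernel = torch_style_kernel_init(key_w, (kh, kw, c_in, c_out))

# PyTorch's conv bias uses the same 1/sqrt(fan_in) bound.
bound = 1.0 / math.sqrt(fan_in)  # 1/12
bias = jax.random.uniform(key_b, (c_out,), minval=-bound, maxval=bound)

print(kernel.shape, float(jnp.max(jnp.abs(kernel))), bound)
```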

One more thing: for Kaiming initialization, PyTorch computes a gain from the nonlinearity (https://github.com/pytorch/pytorch/blob/780f854135c9aa65ffa834ea6c37b914ec774cf4/torch/nn/init.py#L380-L381), which the current JAX initializers do not do.
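JAX's kaiming_uniform fixes the scale at 2.0 (the squared ReLU gain), while PyTorch computes the gain from the nonlinearity; for leaky_relu it is sqrt(2 / (1 + a**2)). A minimal sketch that folds PyTorch's squared gain into variance_scaling's scale, covering only the mode="fan_in", nonlinearity="leaky_relu" case; torch_kaiming_uniform is a hypothetical helper, not a JAX API.

```python
import math

import jax
import jax.numpy as jnp


def torch_kaiming_uniform(a=0.0, in_axis=-2, out_axis=-1):
    """Hypothetical helper mirroring torch.nn.init.kaiming_uniform_ with
    mode="fan_in" and nonlinearity="leaky_relu".

    PyTorch: gain = sqrt(2 / (1 + a**2)), bound = gain * sqrt(3 / fan_in).
    variance_scaling (uniform): bound = sqrt(3 * scale / fan_in),
    so scale = gain**2 = 2 / (1 + a**2).
    """
    scale = 2.0 / (1.0 + a ** 2)
    return jax.nn.initializers.variance_scaling(
        scale, "fan_in", "uniform", in_axis=in_axis, out_axis=out_axis
    )


key = jax.random.PRNGKey(0)
shape = (64, 32)  # (fan_in, fan_out) in JAX's dense-kernel convention

# a = 0 (plain ReLU gain) gives scale 2.0, like JAX's built-in kaiming_uniform.
w_relu = torch_kaiming_uniform(a=0.0)(key, shape)
w_jax = jax.nn.initializers.kaiming_uniform()(key, shape)

# a = sqrt(5) (nn.Linear's default) shrinks the bound to 1/sqrt(fan_in).
w_linear = torch_kaiming_uniform(a=math.sqrt(5))(key, shape)
print(float(jnp.max(jnp.abs(w_linear))), 1 / math.sqrt(shape[0]))
```

With a=0 the helper reproduces jax.nn.initializers.kaiming_uniform() draw for draw; with a=sqrt(5) it matches the weight range of PyTorch's nn.Linear.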

8bitmp3 commented 3 years ago

Very interesting points, and thanks for pointing out the differences between PyTorch and Stax. I'd never checked those! I think of Stax as an experimental extension of JAX, so I'd consider checking out Flax instead.

Flax (google/flax) is what I've been recommended for all things neural-network related in JAX.

I did some quick digging, and by default Flax appears to use LeCun normal (jax.nn.initializers.lecun_normal()) for weights and zeros for biases: https://flax.readthedocs.io/en/latest/_modules/flax/nn/linear.html#Dense

default_kernel_init = initializers.lecun_normal()
...
class Dense(base.Module):
  ...
  def apply(self, ...kernel_init=default_kernel_init,bias_init=initializers.zeros):
  ...

class Conv(base.Module):
  ...
  def apply(self, ...kernel_init=default_kernel_init,bias_init=initializers.zeros):
  ...

The docs also show that with Flax you can pass an initializer other than the default through the kernel_init (weights) and bias_init (bias) parameters (at least in theory).

So you should be able to choose any of the initializers in the core JAX library's neural-net initializer module, jax.nn.initializers: https://jax.readthedocs.io/en/latest/jax.nn.initializers.html
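What makes these interchangeable is a shared calling convention: most entries are factories (e.g. lecun_normal()) that return a function init(key, shape, dtype), while a few, such as zeros, are already such functions. A small sketch, assuming a recent JAX release, that samples one kernel shape from an arbitrary selection:

```python
import jax
import jax.numpy as jnp

init = jax.nn.initializers
key = jax.random.PRNGKey(0)
shape = (128, 64)  # (fan_in, fan_out)

# Factories are called once to obtain an init(key, shape, dtype) function;
# `zeros` is already such a function and is used as-is.
initializers = {
    "lecun_normal": init.lecun_normal(),
    "glorot_uniform": init.glorot_uniform(),
    "he_normal": init.he_normal(),
    "normal(1e-2)": init.normal(stddev=1e-2),
    "zeros": init.zeros,
}

samples = {name: fn(key, shape) for name, fn in initializers.items()}
for name, w in samples.items():
    print(f"{name:15s} std={float(jnp.std(w)):.4f}")
```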

So many to choose from 👍 (is Xavier available, though? https://www.deeplearning.ai/ai-notes/initialization/)

cc @dynamicwebpaige (DM UX)

JiahaoYao commented 3 years ago

Hi @8bitmp3, thanks for your feedback!

JiahaoYao commented 3 years ago

Xavier is available here: https://github.com/google/jax/blob/master/jax/_src/nn/initializers.py#L72-L77
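In JAX the Xavier scheme lives under the Glorot names, and current releases also export xavier_uniform and xavier_normal as aliases. A quick check, assuming a recent JAX release, that the two names produce the same draw and respect the Glorot uniform limit sqrt(6 / (fan_in + fan_out)):

```python
import math

import jax
import jax.numpy as jnp

init = jax.nn.initializers
key = jax.random.PRNGKey(0)
fan_in, fan_out = 100, 50

# xavier_uniform is an alias for glorot_uniform, so the same key gives the same draw.
w_xavier = init.xavier_uniform()(key, (fan_in, fan_out))
w_glorot = init.glorot_uniform()(key, (fan_in, fan_out))

# Glorot/Xavier uniform samples U(-limit, limit), limit = sqrt(6 / (fan_in + fan_out)).
limit = math.sqrt(6.0 / (fan_in + fan_out))
print(bool(jnp.array_equal(w_xavier, w_glorot)), float(jnp.max(jnp.abs(w_xavier))), limit)
```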

JiahaoYao commented 3 years ago

For the bias initialization, the API is not perfect. Please look at the colab: https://colab.research.google.com/drive/1I6X8_laXjtBq9y5Bri16Jje3npCg_ai6?usp=sharing

8bitmp3 commented 3 years ago

No worries, @JiahaoYao!

> Xavier is available here: https://github.com/google/jax/blob/master/jax/_src/nn/initializers.py#L72-L77

Ah yes, thanks; I forgot about the different names.

> for the bias initialization, the API is not perfect. Please look at the colab:

I think you're still using jax.experimental.stax. I'd recommend giving Flax a go: https://github.com/google/flax (docs: https://flax.readthedocs.io/en/latest/).

Update: try flax.linen instead of flax.nn, e.g. https://github.com/google/flax/blob/master/flax/linen/linear.py

JiahaoYao commented 3 years ago

Cool, will do!

8bitmp3 commented 3 years ago

@JiahaoYao, if Flax isn't enough for you, check out DeepMind's Haiku library (also built on JAX). For example, you can initialize params with a truncated normal, hk.initializers.TruncatedNormal() (example), to mitigate the dead (☠️) neuron issue.
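For comparison, a core-JAX sketch of the same idea (not Haiku's implementation): a hypothetical truncated_normal_init built on jax.random.truncated_normal, with a cutoff of 2 standard deviations chosen as an assumption. Truncation discards the tails, so no weight starts at an extreme value; the resulting spread is somewhat smaller than stddev.

```python
import jax
import jax.numpy as jnp


def truncated_normal_init(stddev=1.0, cutoff=2.0):
    """Hypothetical initializer: N(0, stddev**2) truncated to +/- cutoff * stddev."""

    def init(key, shape, dtype=jnp.float32):
        # truncated_normal samples a standard normal restricted to [-cutoff, cutoff].
        return stddev * jax.random.truncated_normal(key, -cutoff, cutoff, shape, dtype)

    return init


w = truncated_normal_init(stddev=0.05)(jax.random.PRNGKey(0), (256, 128))
print(float(jnp.max(jnp.abs(w))))  # bounded by 2 * 0.05
```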

JiahaoYao commented 3 years ago

Thanks so much, @8bitmp3! I will close this issue and may reopen it if I have further questions. Thank you again!