n2cholas / jax-resnet

Implementations and checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX (Flax).
https://pypi.org/project/jax-resnet/
MIT License

Initialization is incorrect #3

Closed n2cholas closed 3 years ago

n2cholas commented 3 years ago

Currently, all layers use Flax's default initialization. However, each paper specifies a different strategy (a sketch of how these map to Flax initializers follows the list):

  1. ResNet, Wide ResNet, and ResNeXt use Kaiming Normal.
  2. ResNet-D uses Xavier Uniform.
  3. The ResNeSt paper says it uses Kaiming Normal, but the code uses the PyTorch default, which is Kaiming Uniform with a=sqrt(5).
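
For reference, here is a rough sketch of how these three strategies could be expressed with `jax.nn.initializers`. The `variance_scaling` parameters for the PyTorch default are my own derivation (the a=sqrt(5) gain works out to a scale of 1/3), so treat that mapping as an assumption rather than a verified equivalence:

```python
from flax import linen as nn
from jax.nn import initializers

# 1. ResNet / Wide ResNet / ResNeXt: Kaiming (He) Normal.
kaiming_normal_init = initializers.kaiming_normal()

# 2. ResNet-D: Xavier (Glorot) Uniform.
xavier_uniform_init = initializers.xavier_uniform()

# 3. PyTorch's default Conv2d init: Kaiming Uniform with a=sqrt(5).
#    gain^2 = 2 / (1 + a^2) = 1/3, so this should (assumption) match
#    fan_in variance scaling with scale 1/3 and a uniform distribution.
torch_default_init = initializers.variance_scaling(
    scale=1 / 3, mode='fan_in', distribution='uniform')

# Any of these can be passed to a Flax layer via kernel_init, e.g.:
conv = nn.Conv(features=64, kernel_size=(3, 3),
               kernel_init=kaiming_normal_init)
```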

There are a few options going forward:

  1. Set all the models to use Kaiming [Normal or Uniform], which has been shown to work best with ReLU activations. With this choice, we'd probably deviate from the PyTorch default gain (which is calculated for LeakyReLU) to a gain suitable for vanilla ReLU.
  2. Set each model to the initialization specified in its respective paper.
  3. Provide no default: force users to select an initializer, and suggest suitable candidates in the docstring (a sketch of what this API could look like follows the list).
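
As a minimal sketch of option 3: the module below requires the caller to pass an initializer explicitly. The `ResNetBlock` module and the `conv_kernel_init` argument name are hypothetical illustrations, not the library's actual API:

```python
from typing import Callable

from flax import linen as nn


class ResNetBlock(nn.Module):
    """Hypothetical residual block: the caller must supply an initializer."""
    features: int
    conv_kernel_init: Callable  # no default: the user must choose one, e.g.
                                # nn.initializers.kaiming_normal() or
                                # nn.initializers.xavier_uniform()

    @nn.compact
    def __call__(self, x):
        # Assumes x already has `features` channels so the residual add works.
        y = nn.Conv(self.features, (3, 3),
                    kernel_init=self.conv_kernel_init)(x)
        y = nn.relu(y)
        y = nn.Conv(self.features, (3, 3),
                    kernel_init=self.conv_kernel_init)(y)
        return nn.relu(y + x)


# Usage: the choice of initializer is explicit at construction time.
block = ResNetBlock(features=64,
                    conv_kernel_init=nn.initializers.kaiming_normal())
```

This makes the trade-off explicit in the type signature: there is no silent fallback to a default that may not match any paper, at the cost of slightly more verbose user code.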