majumderb / rezero

Official PyTorch Repo for "ReZero is All You Need: Fast Convergence at Large Depth"
https://arxiv.org/pdf/2003.04887.pdf
MIT License

Relationship between ReZero and Zero gamma trick #6

Closed: hukkai closed this issue 4 years ago

hukkai commented 4 years ago

Hello! Thanks for your interesting work and useful code.

I have one small question. In Table 1 of the paper, the formulation of Residual Network + Pre-Norm is $x_{i+1} = x_i + F(\mathrm{Norm}(x_i))$. From my understanding, the corresponding formulation of Residual Network + Post-Norm should be $x_{i+1} = x_i + \mathrm{Norm}(F(x_i))$, which is also what ResNet does in practice. But the paper uses a different formulation, $x_{i+1} = \mathrm{Norm}(x_i + F(x_i))$. Is this a typo, or am I misunderstanding something?
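To make the three formulations concrete, here is a minimal pure-Python sketch. The residual branch `F` (a fixed elementwise affine map) and the normalization `norm` (zero mean, unit variance across the vector, with no learnable scale or shift) are hypothetical stand-ins chosen for illustration, not the layers used in the paper or in ResNet.

```python
import math

def norm(v, eps=1e-5):
    """Toy normalization (hypothetical): zero mean, unit variance across the vector, no affine."""
    mean = sum(v) / len(v)
    var = sum((a - mean) ** 2 for a in v) / len(v)
    return [(a - mean) / math.sqrt(var + eps) for a in v]

def F(v):
    """Hypothetical stand-in for the residual branch: a fixed elementwise affine map."""
    return [2.0 * a + 1.0 for a in v]

def add(u, v):
    return [a + b for a, b in zip(u, v)]

def pre_norm(x):
    # x_{i+1} = x_i + F(Norm(x_i))
    return add(x, F(norm(x)))

def post_norm_outer(x):
    # x_{i+1} = Norm(x_i + F(x_i)): normalization applied after the addition
    return norm(add(x, F(x)))

def post_norm_branch(x):
    # x_{i+1} = x_i + Norm(F(x_i)): normalization at the end of the residual branch
    return add(x, norm(F(x)))

x = [1.0, 2.0, 3.0, 4.0]
for block in (pre_norm, post_norm_outer, post_norm_branch):
    print(block.__name__, [round(a, 3) for a in block(x)])
```

Only the outer variant normalizes the block output itself (here it comes out with zero mean); the other two leave the identity path $x_i$ untouched, which is the property the zero gamma trick relies on.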

In this $x_{i+1} = x_i + \mathrm{Norm}(F(x_i))$ formulation, a trick called the zero gamma trick (setting gamma = 0 in every batch normalization layer whose output goes back to the main branch) is commonly used [1,2]. Similarly, Fixup Initialization [3] builds on the same idea and has been shown to train very deep networks. The trick is used in both the PyTorch and TensorFlow ResNet implementations. What is the relationship between ReZero and the zero gamma trick? Thanks!
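A minimal sketch of how the zero gamma trick and ReZero both start from an identity map, again in pure Python with hypothetical stand-ins for `F` and `norm`. For simplicity, `norm` normalizes across the features of a single vector, whereas real batch normalization normalizes each channel over the batch; the point here is only the zero-initialized gate. In the zero gamma case the gate is the per-channel vector `gamma` (plus a shift `beta`), while ReZero uses one scalar per block, $x_{i+1} = x_i + \alpha_i F(x_i)$.

```python
import math

def norm(v, eps=1e-5):
    """Toy normalization (hypothetical): zero mean, unit variance across the vector, not true batch norm."""
    mean = sum(v) / len(v)
    var = sum((a - mean) ** 2 for a in v) / len(v)
    return [(a - mean) / math.sqrt(var + eps) for a in v]

def F(v):
    """Hypothetical residual branch: a fixed elementwise affine map."""
    return [2.0 * a + 1.0 for a in v]

def zero_gamma_block(x, gamma, beta):
    # x_{i+1} = x_i + (gamma * Norm(F(x_i)) + beta); gamma and beta are per-channel vectors
    h = norm(F(x))
    return [xi + g * hi + b for xi, g, hi, b in zip(x, gamma, h, beta)]

def rezero_block(x, alpha):
    # x_{i+1} = x_i + alpha * F(x_i); alpha is a single scalar per block
    return [xi + alpha * fi for xi, fi in zip(x, F(x))]

x = [1.0, 2.0, 3.0, 4.0]
d = len(x)

# Initialization: gamma = 0 (and beta = 0) for the zero gamma trick, alpha = 0 for ReZero.
gamma, beta = [0.0] * d, [0.0] * d
alpha = 0.0

print(zero_gamma_block(x, gamma, beta) == x)  # True: block is the identity at init
print(rezero_block(x, alpha) == x)            # True: block is the identity at init
print("gate parameters:", len(gamma), "vs", 1)
```

Both blocks reduce to $x_{i+1} = x_i$ at initialization; they differ in where the gate sits (after a normalization layer vs. directly on $F$) and in how many gate parameters each block has ($d$ vs. 1).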

[1] Goyal et al. Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour.
[2] He et al. Bag of Tricks for Image Classification with Convolutional Neural Networks.
[3] Zhang et al. Fixup Initialization: Residual Learning Without Normalization.

majumderb commented 4 years ago

We tried the "zero gamma trick", which initializes the last batchnorm layer (its gamma) to zero; it improves final validation accuracy. ReZero reaches the same final accuracy, but about 2x faster.

hukkai commented 4 years ago

Thanks for your response! The difference in convergence speed is quite interesting. I look forward to ablation studies comparing the two cases (a per-channel gamma vector vs. a single scalar), or some explanation of why they converge differently.