facebookresearch / DiT

Official PyTorch Implementation of "Scalable Diffusion Models with Transformers"

Clarification on Zero Initialization in FinalLayer of DiT Model #82

Open denemmy opened 7 months ago

denemmy commented 7 months ago

Hello Facebook Research Team,

I am exploring DiT as implemented in your repository and came across the weight initialization strategy for the FinalLayer, specifically in this section of the code.

The weights for the linear layer in the FinalLayer are initialized to zeros:

```python
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
```

Typically, neural network weights are initialized with non-zero values to break symmetry and ensure diverse feature learning. While I understand the rationale behind zero initialization of modulation weights in other parts of the model, the zero initialization in this linear layer caught my attention.
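A minimal NumPy sketch of what these zero initializations mean at step 0. It assumes a FinalLayer-like forward pass: a LayerNorm without affine parameters, adaLN-style modulation of the form `x * (1 + scale) + shift`, then the linear projection. The sizes, the helper names, and the modulation form are assumptions for illustration, not code taken from the repository. With zero modulation weights the modulation is a no-op, and with a zero linear layer the prediction is exactly zero for any input.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden, out_dim, tokens = 8, 4, 5       # made-up sizes for illustration

x = rng.normal(size=(tokens, hidden))   # token features entering the final layer
c = rng.normal(size=(hidden,))          # hypothetical conditioning vector


def layer_norm(v, eps=1e-6):
    # LayerNorm without learnable affine parameters
    mu = v.mean(axis=-1, keepdims=True)
    var = v.var(axis=-1, keepdims=True)
    return (v - mu) / np.sqrt(var + eps)


def modulate(v, shift, scale):
    # Assumed adaLN-style modulation: v * (1 + scale) + shift
    return v * (1 + scale) + shift


def silu(v):
    # SiLU(v) = v * sigmoid(v)
    return v / (1 + np.exp(-v))


# Zero-initialized parameters, mirroring the two init lines above
W_mod = np.zeros((hidden, 2 * hidden))  # modulation projection
b_mod = np.zeros(2 * hidden)
W_lin = np.zeros((hidden, out_dim))     # final linear layer
b_lin = np.zeros(out_dim)

shift, scale = np.split(silu(c) @ W_mod + b_mod, 2)  # both exactly zero at init
h = modulate(layer_norm(x), shift, scale)            # identical to layer_norm(x)
out = h @ W_lin + b_lin                              # exactly zero for any input

print(np.allclose(h, layer_norm(x)), np.all(out == 0))  # True True
```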

Is the zero initialization of weights in this non-modulation linear layer intentional, and could you provide any insights into this choice?

Thank you for any information or insights you can provide!

Best regards, Danil.

tanghengjian commented 6 months ago

Zero initialization may help make the model stable and reproducible?

shy19960518 commented 6 months ago

I have the same confusion. The most surprising thing is that the model still learns well in my experiments. Can someone explain? ^ ^

zhaohm14 commented 1 month ago

Hi Danil,

I have the same confusion. However, although I don't understand what benefit the zero initialization of final_layer.linear brings, I believe it should not cause a symmetry problem that hinders training.

The symmetry problem occurs most often in multi-layer networks with hidden nodes. If all hidden nodes in the same layer start with identical weights, they compute identical values and receive identical gradients during backpropagation, so they stay identical and the hidden layer effectively functions as a single node.
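The symmetry problem can be reproduced with a small NumPy sketch, assuming a hypothetical 3-4-1 tanh MLP trained with plain gradient descent on random data, with every weight set to the same constant (0.5, chosen arbitrarily). Because all hidden nodes start identical, they receive identical gradients and are still identical after training.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))    # batch of 8 inputs with 3 features
y = rng.normal(size=(8, 1))    # random regression targets

# Hypothetical 2-layer MLP: 3 -> 4 hidden (tanh) -> 1, all weights identical
W1 = np.full((3, 4), 0.5)
W2 = np.full((4, 1), 0.5)

lr = 0.1
for _ in range(20):
    H = np.tanh(X @ W1)              # hidden activations, shape (8, 4)
    out = H @ W2                     # predictions, shape (8, 1)
    dout = 2 * (out - y) / len(X)    # dL/d(out) for mean squared error
    dW2 = H.T @ dout                 # dL/dW2 = H^T . dL/d(out)
    dH = dout @ W2.T                 # back-propagate into the hidden layer
    dW1 = X.T @ (dH * (1 - H ** 2))  # tanh'(z) = 1 - tanh(z)^2
    W1 -= lr * dW1
    W2 -= lr * dW2

# Weights changed, yet every hidden node still has the same incoming weights
print(np.allclose(W1, W1[:, :1]))  # True
```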

To avoid the symmetry problem in neural networks, at each layer either the inputs $I$ or the gradients with respect to the outputs, $\frac{\partial L}{\partial O}$, must not be symmetric. This is because the gradient with respect to the weights is calculated as $\frac{\partial L}{\partial W} = I^T \cdot \frac{\partial L}{\partial O}$, and asymmetry in either term ensures diverse weight updates.
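A quick numerical check of the weight-gradient formula, as a NumPy sketch. It assumes a bias-free linear layer written in the row-vector convention $O = I W$ (each row of $I$ is one sample) and a squared-error loss chosen only for illustration; the analytic $I^T \cdot \frac{\partial L}{\partial O}$ is compared against central finite differences.

```python
import numpy as np

rng = np.random.default_rng(1)
I = rng.normal(size=(5, 3))   # inputs: batch of 5 samples, 3 features
W = rng.normal(size=(3, 2))   # weights: 3 inputs -> 2 outputs
T = rng.normal(size=(5, 2))   # hypothetical regression targets


def loss(W):
    O = I @ W                          # layer output
    return 0.5 * np.sum((O - T) ** 2)  # squared-error loss


dL_dO = I @ W - T        # gradient of the loss w.r.t. the outputs
dL_dW = I.T @ dL_dO      # the formula: I^T . dL/dO

# Central finite differences, perturbing one weight at a time
eps = 1e-6
numeric = np.zeros_like(W)
for i in range(W.shape[0]):
    for j in range(W.shape[1]):
        E = np.zeros_like(W)
        E[i, j] = eps
        numeric[i, j] = (loss(W + E) - loss(W - E)) / (2 * eps)

print(np.allclose(dL_dW, numeric, atol=1e-6))  # True
```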

However, there is no hidden layer in final_layer.linear or adaLN_modulation. Although the outputs and weights are symmetric (all zero) in the first step, the inputs are not. This asymmetry in the inputs ensures that each weight is updated differently, thus breaking the symmetry.
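This argument can be checked with a NumPy sketch, assuming a single zero-initialized linear layer (standing in for final_layer.linear) fitted by gradient descent to hypothetical linear targets from random, non-symmetric inputs. After one step the output units already have different weights, and the layer eventually recovers the target mapping. One side effect is visible as well: at the first step the gradient passed back to the inputs, $\frac{\partial L}{\partial O} W^T$, is exactly zero because it is multiplied by the zero weights, so in a network whose output passes through this layer, earlier layers only start receiving gradient after the first update.

```python
import numpy as np

rng = np.random.default_rng(2)
I = rng.normal(size=(64, 4))      # inputs to the layer: varied, hence not symmetric
W_true = rng.normal(size=(4, 3))  # hypothetical "ideal" weights
T = I @ W_true                    # targets the layer should learn to produce

W = np.zeros((4, 3))              # zero-initialized weight
b = np.zeros(3)                   # zero-initialized bias


def sgd_step(W, b, lr=0.2):
    O = I @ W + b                 # forward pass
    dL_dO = (O - T) / len(I)      # gradient of 0.5 * sum((O - T)**2) / N
    dL_dI = dL_dO @ W.T           # gradient that would flow to earlier layers
    dL_dW = I.T @ dL_dO           # I^T . dL/dO
    dL_db = dL_dO.sum(axis=0)
    return W - lr * dL_dW, b - lr * dL_db, dL_dI


W, b, dL_dI_first = sgd_step(W, b)
W_first = W.copy()

for _ in range(1000):
    W, b, _ = sgd_step(W, b)

print(np.allclose(dL_dI_first, 0))                # True: nothing flows upstream yet
print(np.allclose(W_first[:, 0], W_first[:, 1]))  # False: output units already differ
print(np.allclose(W, W_true, atol=1e-3))          # True: the mapping is still learned
```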