microsoft / torchscale

Foundation Architecture for (M)LLMs
https://aka.ms/GeneralAI
MIT License

initialization of qkv #68

Closed · XintianHan closed this 1 year ago

XintianHan commented 1 year ago

In the paper, the authors mention that the initialization follows DeepNet, but the code looks somewhat different. Why is there a mismatch?

def reset_parameters(self):
    # q/k/v and gate projections use a reduced Xavier gain of 2 ** -2.5
    nn.init.xavier_uniform_(self.q_proj.weight, gain=2 ** -2.5)
    nn.init.xavier_uniform_(self.k_proj.weight, gain=2 ** -2.5)
    nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5)
    nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5)
    # output projection keeps the default gain of 1.0; bias is zeroed
    nn.init.xavier_uniform_(self.out_proj.weight)
    nn.init.constant_(self.out_proj.bias, 0.0)
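
(For reference: nn.init.xavier_uniform_ draws from U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out)), so a gain of 2 ** -2.5 ≈ 0.177 simply shrinks the default Xavier scale by that factor. A minimal sketch checking this empirically, with an assumed embed_dim of 512:)

    import math
    import torch.nn as nn

    # Illustrative sketch (not from the repo): measure the weight scale that
    # gain=2 ** -2.5 produces on a square projection.
    embed_dim = 512  # assumed for illustration
    proj = nn.Linear(embed_dim, embed_dim, bias=False)
    nn.init.xavier_uniform_(proj.weight, gain=2 ** -2.5)

    # xavier_uniform_ samples U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out)),
    # which gives std = gain * sqrt(2 / (fan_in + fan_out)).
    expected_std = 2 ** -2.5 * math.sqrt(2 / (embed_dim + embed_dim))
    print(proj.weight.std().item(), expected_std)  # both around 0.0078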
shumingma commented 1 year ago

RetNet uses DeepNet's derivation method to obtain the initialization for better training stability, instead of directly re-using its derived initialization (for Post-LN Transformers), because the initialization depends on the model architecture according to the theory in DeepNet.
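
(For context: the initialization the DeepNet paper derives for a Post-LN, decoder-only Transformer with N layers applies Xavier init with a depth-dependent gain beta = (8N) ** -0.25 to the value and output projections and the FFN, while q/k keep gain 1. A rough sketch of that scheme for comparison with the fixed, depth-independent 2 ** -2.5 gain above; `layer` and its attribute names are assumed for illustration:)

    import torch.nn as nn

    # Hedged sketch of DeepNet's Post-LN (decoder-only) initialization, per
    # the DeepNet paper; not the scheme used in this repo's RetNet code.
    def deepnet_reset_parameters(layer, num_layers):
        beta = (8 * num_layers) ** -0.25              # depth-dependent gain
        nn.init.xavier_uniform_(layer.q_proj.weight, gain=1.0)
        nn.init.xavier_uniform_(layer.k_proj.weight, gain=1.0)
        nn.init.xavier_uniform_(layer.v_proj.weight, gain=beta)
        nn.init.xavier_uniform_(layer.out_proj.weight, gain=beta)
        # DeepNet also scales the FFN weights by beta (omitted here)

(Note the architectural difference: DeepNet scales only v/out/FFN and ties the gain to depth, whereas the RetNet code above applies one fixed gain to q/k/v/g, which is why the derived numbers differ.)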

XintianHan commented 1 year ago

> RetNet uses DeepNet's derivation method to obtain the initialization for better training stability, instead of directly re-using its derived initialization (for Post-LN Transformers), because the initialization depends on the model architecture according to the theory in DeepNet.

Thanks for the quick reply!

"because the initialization depends on the model architecture according to the theory in DeepNet"

Could you elaborate on the derivation method? How do you get the number 2 ** -2.5 here? Thanks!

radarFudan commented 10 months ago

I am also interested in this initialization scheme. Recurrent models such as S4 and S5 seem to use different schemes. Do you have any particular explanation or heuristic for this scale?