pytorch / torchtitan

A native PyTorch Library for large model training
BSD 3-Clause "New" or "Revised" License

About the reference for weight init based on layer depth or layer id #375

Closed: SeunghyunSEO closed this issue 3 weeks ago

SeunghyunSEO commented 1 month ago

Hi, first of all, thank you for the nice open-source project! I have been reading your model code and noticed that it initializes model weights based on num_layers or layer_id. This is not a conventional scheme like Kaiming-style fan-in init (std = 1/\sqrt{fan_in}) or GPT-2 init (std = 0.02), and it also does not look like muP or anything similar. So I just want to know whether there is a reference for it, or whether it was chosen empirically for training stability.

Edit: I forgot that the std = 0.02/\sqrt{depth} init for the output layers of residual blocks comes from the GPT-2 paper, sorry! I'm just wondering where depth_init comes from.
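For comparison, the conventional schemes mentioned in the question can be written as standard deviations. This is a minimal sketch in plain Python; the helper names are hypothetical, and counting two residual branches (attention and MLP) per block, so that depth = 2 * n_layers, is an assumption borrowed from how nanoGPT applies the GPT-2 residual scaling.

```python
import math

# Hypothetical helpers, for illustration only.

def kaiming_style_std(fan_in: int) -> float:
    """Fan-in scaling with gain 1 (linear layer): std = 1 / sqrt(fan_in)."""
    return 1.0 / math.sqrt(fan_in)


# GPT-2: a fixed std for most weights, independent of width and depth.
GPT2_STD = 0.02


def gpt2_residual_std(n_layers: int, base_std: float = GPT2_STD) -> float:
    """GPT-2 style std for the output projection of each residual branch.

    Assumes two residual branches (attention + MLP) per block, as nanoGPT
    does, so depth = 2 * n_layers. Every layer gets the same value.
    """
    return base_std / math.sqrt(2 * n_layers)


if __name__ == "__main__":
    d_model, n_layers = 4096, 32
    print(f"Kaiming-style (fan_in={d_model}): {kaiming_style_std(d_model):.6f}")
    print(f"GPT-2 base std: {GPT2_STD:.6f}")
    print(f"GPT-2 residual-scaled (n_layers={n_layers}): {gpt2_residual_std(n_layers):.6f}")
```

With d_model = 4096 and 32 layers, the fan-in rule gives 1/64 ≈ 0.0156, while the GPT-2 residual scaling gives 0.02/8 = 0.0025 for every output projection, regardless of where the layer sits in the stack.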

wanchaol commented 1 month ago

@lessw2020

lessw2020 commented 3 weeks ago

Hi @SeunghyunSEO, sorry for the delay; I didn't see this earlier. To answer your question: the depth init came out of research we did last summer on parallel attention blocks. I ran a comparison sweep, adding the depth init was the winner, and so we have continued to use it. I'm not sure where the concept originally came from; it bubbled up in discussions with IBM Research a while back. I did see that OLMo also uses it and refers to it as "Mitchell init", but I was not able to find anything on arXiv about it. Anyway, the short answer is that it's empirically based. We haven't re-run a sweep on it since Llama 3 came out, so we may revisit it in the future, but it continues to perform well in our training runs. Hope that helps!
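To make the difference from GPT-2 scaling concrete, here is a small sketch of how a depth-based init could assign a per-layer std to the residual-branch output projections. The formulas are an assumption modeled on torchtitan's Llama TransformerBlock around the time of this thread (0.02 / sqrt(2 * (layer_id + 1)) with depth init, 0.02 / sqrt(2 * n_layers) otherwise), and `output_proj_std` is a hypothetical helper; check the current source before relying on the exact constants or on which projections receive this std.

```python
import math


def output_proj_std(layer_id: int, n_layers: int, depth_init: bool,
                    base_std: float = 0.02) -> float:
    """Std for a block's residual-branch output projections (hypothetical helper).

    Assumed formulas, modeled on torchtitan's Llama block at the time:
      depth_init=True  -> base_std / sqrt(2 * (layer_id + 1))  (shrinks with depth)
      depth_init=False -> base_std / sqrt(2 * n_layers)        (same for all layers)
    layer_id is 0-based.
    """
    if depth_init:
        return base_std / math.sqrt(2 * (layer_id + 1))
    return base_std / math.sqrt(2 * n_layers)


if __name__ == "__main__":
    n_layers = 8
    for layer_id in range(n_layers):
        d = output_proj_std(layer_id, n_layers, depth_init=True)
        g = output_proj_std(layer_id, n_layers, depth_init=False)
        print(f"layer {layer_id}: depth_init={d:.5f}  gpt2-style={g:.5f}")
```

Under these formulas the last layer gets exactly the GPT-2-style value, and every earlier layer gets a larger std, so the two schemes differ by giving shallower layers bigger initial output projections.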

SeunghyunSEO commented 2 weeks ago

Thank you for the kind answer, @lessw2020! I guess it makes sense, because with Mitchell init the layer-wise output variance would be more consistent than with GPT-2 init.
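The variance argument can be explored with a crude back-of-the-envelope model. This sketch is an assumption-laden toy, not a measurement: it treats each residual branch as a single d_model x d_model output projection applied to a unit-variance (RMS-normalized) input, ignores nonlinearities and correlations, assumes two branches per block, and uses the per-layer std formulas 0.02 / sqrt(2 * (layer_id + 1)) for depth init and 0.02 / sqrt(2 * n_layers) for GPT-2-style scaling, which are themselves an assumption about torchtitan's code. `residual_variance_profile` is a hypothetical helper.

```python
import math


def residual_variance_profile(n_layers: int, d_model: int, depth_init: bool,
                              base_std: float = 0.02, v0: float = 1.0) -> list[float]:
    """Toy residual-stream variance after each block of a pre-norm transformer.

    Assumptions (toy model only):
      * each branch sees an RMS-normalized input with unit variance per coordinate,
      * each branch is one d_model x d_model projection with i.i.d. weights of
        std s, so it adds variance d_model * s**2 per coordinate,
      * two branches (attention + MLP) per block, uncorrelated with the stream.
    """
    v = v0
    profile = []
    for layer_id in range(n_layers):
        if depth_init:
            s = base_std / math.sqrt(2 * (layer_id + 1))
        else:
            s = base_std / math.sqrt(2 * n_layers)
        v += 2 * d_model * s ** 2  # contribution of the two branches in this block
        profile.append(v)
    return profile


if __name__ == "__main__":
    for depth_init in (False, True):
        prof = residual_variance_profile(32, 4096, depth_init)
        label = "depth_init" if depth_init else "gpt2-style"
        print(label, [round(x, 3) for x in prof[::8]], round(prof[-1], 3))
```

Under these assumptions, GPT-2-style scaling adds the same variance in every block, while depth init adds 1/(layer_id + 1) as much as in the first block, so most of its variance growth happens in the early layers. Whether that makes per-layer outputs more consistent in a real model is the kind of question the sweep mentioned in the thread answers empirically.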