jzhang38 / TinyLlama

The TinyLlama project is an open endeavor to pretrain a 1.1B Llama model on 3 trillion tokens.
Apache License 2.0

Why change the `_init_weights`? #117

Closed · larrylawl closed this issue 6 months ago

larrylawl commented 6 months ago

Hi, I noticed that the team changed the `_init_weights` function in this commit: https://github.com/jzhang38/TinyLlama/commit/89c75f40ec99cb48ea94c92423c1deae67bc2329

Can I check the motivation for doing so? Thanks!

jzhang38 commented 6 months ago

The previous implementation only scaled the MLP output during init. We decided to follow the init method of https://github.com/kingoflolz/mesh-transformer-jax and also scale the attention output (see GPT-NeoX-20B: An Open-Source Autoregressive Language Model, Section 2.1.3).
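
For anyone skimming the thread, here is a minimal sketch (not TinyLlama's actual code) of the scheme being described: a "small init" for ordinary weights plus a depth-scaled init for the residual output projections of both the attention block and the MLP. The parameter names `attn.proj.weight` and `mlp.proj.weight` are assumed lit_gpt-style names, not quoted from the commit.

```python
import math

import torch.nn as nn


def small_init_(module, n_embd: int):
    # "Small init" (Transformers without Tears): std = sqrt(2 / (5 * n_embd)),
    # applied to embeddings and ordinary linear layers.
    if isinstance(module, (nn.Embedding, nn.Linear)):
        nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / n_embd))
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)


def scale_residual_projections_(model: nn.Module, n_embd: int, n_layer: int):
    # GPT-NeoX / mesh-transformer-jax style: the output projection of *both*
    # the attention block and the MLP gets a smaller, depth-scaled std.
    # (The earlier code only rescaled the MLP output.)
    for name, p in model.named_parameters():
        if name.endswith("attn.proj.weight") or name.endswith("mlp.proj.weight"):
            nn.init.normal_(p, mean=0.0, std=1.0 / math.sqrt(n_embd) / n_layer)
```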

larrylawl commented 6 months ago

Thanks @jzhang38 for replying!

Can I clarify why this change was made? When I printed `module.weight.size(1)` inside the `elif isinstance(module, nn.Linear)` block, the values vary: they can be either `n_embd` or `intermediate_size`. Given that the small_init scheme is a function of the dimension, shouldn't it be `module.weight.size(1)` instead of `n_embd`?


For the following line, shouldn't the numerator be 2 instead of 1?

https://github.com/jzhang38/TinyLlama/blob/b4964afba313499e5e6e15e8573c7648982c8896/lit_gpt/model.py#L54


Sorry Peiyuan, can you point me to the code that implements this? I can't seem to find it in the codebase:

> We decided to follow the init method of https://github.com/kingoflolz/mesh-transformer-jax and also scale the attention output.

larrylawl commented 6 months ago

Bump, just in case you missed it @jzhang38

jzhang38 commented 6 months ago

Hi Larry, sorry for the late reply. I have been busy with another project these days. Next time you can ping me on Twitter if I forget to respond.

> For the following line, shouldn't the numerator be 2 instead of 1?

GPT-NeoX uses parallel attention and FF layers, and its Section 2.1.3 says "with the factor of 2 compensating for the fact that the parallel and feed-forward layers are organized in parallel". Since Llama uses sequential attention & FF, the numerator should be 1.
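
To make the factor concrete, here is a small sketch of the two residual-output stds; the dimensions in the example are illustrative, not quoted from a config:

```python
import math


def residual_out_std(n_embd: int, n_layer: int, parallel: bool) -> float:
    # GPT-NeoX (parallel attention + FF):       std = 2 / (n_layer * sqrt(n_embd))
    # Llama-style (sequential attention, FF):   std = 1 / (n_layer * sqrt(n_embd))
    numerator = 2.0 if parallel else 1.0
    return numerator / n_layer / math.sqrt(n_embd)


# Illustrative dimensions only:
print(residual_out_std(2048, 22, parallel=True))   # ~0.0020
print(residual_out_std(2048, 22, parallel=False))  # ~0.0010
```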

> When I printed `module.weight.size(1)` inside the `elif isinstance(module, nn.Linear)` block, the values vary: they can be either `n_embd` or `intermediate_size`. Given that the small_init scheme is a function of the dimension, shouldn't it be `module.weight.size(1)` instead of `n_embd`?

This is a really good question. I had the same doubt initially. In the end, I decided to follow the implementation of GPT-NeoX, which uses the `n_embd` value instead of `weight.size(1)`: https://github.com/EleutherAI/gpt-neox/blob/f14782a571b9b4ff52803ce57c2bfc650670c30a/megatron/model/init_functions.py#L204

I think reading Section 2.2 of Transformers without Tears: Improving the Normalization of Self-Attention could resolve your doubt. In short, the reason we use "5d" is that the sum of input and output dims for the FF up/down-projection layer is 5d. For the attention qkv projections, this sum is "2d". In that paper, the author proposes initializing the attention qkv projections with the 5d value as well, hence the name "small init". Note that d here refers to the transformer's `n_embd`.
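
A small sketch of that arithmetic, with illustrative numbers and a helper name of my own choosing:

```python
import math


def xavier_normal_std(fan_in: int, fan_out: int) -> float:
    # Glorot/Xavier-normal: std = sqrt(2 / (fan_in + fan_out))
    return math.sqrt(2.0 / (fan_in + fan_out))


d = 2048                               # model width (n_embd), illustrative
print(xavier_normal_std(d, 4 * d))     # FF up-projection: fan sum = 5d -> ~0.0140
print(xavier_normal_std(d, d))         # attention q/k/v:  fan sum = 2d -> ~0.0221
# "Small init" reuses the 5d value, sqrt(2 / (5 * d)), for the attention
# projections too, always in terms of n_embd rather than weight.size(1).
```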

(On the other hand, the hidden dim of SwiGLU is actually 8/3·d instead of 4d. You can see this discussed in another issue.)
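
For reference, a quick sketch of where the 8/3 comes from (illustrative dimensions, not quoted from the config):

```python
# A standard FFN has 2 matrices of shape (d, 4d): 8*d^2 parameters.
# SwiGLU has 3 matrices of shape (d, h): 3*h*d parameters, which matches
# 8*d^2 when h = 8/3 * d.
n_embd = 2048                 # illustrative model width
hidden = 8 / 3 * n_embd
print(hidden)                 # 5461.33..., rounded to a hardware-friendly size in practice
```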

larrylawl commented 6 months ago

No worries at all. Thanks Peiyuan!