openlm-research / open_llama

OpenLLaMA, a permissively licensed open source reproduction of Meta AI’s LLaMA 7B trained on the RedPajama dataset
Apache License 2.0

How did you initialize llama? #98

Open brando90 opened 5 months ago

brando90 commented 5 months ago

My code:

import torch.nn as nn

def reinitialize_weights_gpt_neox_20B_inspired_4_llama2(model):
    """
    Note: we roughly follow gpt-neox_20B (2022) & transformers without tears (2019); llama1/llama2 do not say how they initialize.

    I think gpt-neox_20B & llama2 both use pre-layernorm: gpt-neox_20B uses the init from transformers
    without tears, and llama1 says it uses pre-norm, so probably both are pre-layernorm.
    Thus, I hope the same init that transformers without tears uses works here too.

    Init:
    FF layer: (as Wang 2021, not transformers without tears)
        -> W ~ N(0, 3 / (L * sqrt(D)))
        chose this because 2021 is later than transformers without tears (2019, Nguyen & Salazar)
    Other layers (as transformers without tears (2019, Nguyen & Salazar)):
        -> W ~ N(0, sqrt(2 / (d + 4d)))
    norm_layer:
        gpt-neox_20B uses layer_norm
        llama2 uses the same arch as llama1, which uses RMSNorm (Zhang and Sennrich, 2019)
        decided not to copy gpt-neox_20B's norm init (W ~ N(0, sqrt(2 / (d + 4d))))
        because they don't share the same norm. llama1/2 use RMSNorm:
            a_bar_i = g_i * a_i / RMS(a), with RMS(a) = sqrt(1/n * sum_j a_j^2)
            (in implementations eps is added inside the sqrt: sqrt(1/n * sum_j a_j^2 + eps);
            see the RMSNorm sketch after the function)
        So I decided
        -> g_i (gain) = constant(1)
        since we aren't training to completion, keeping the identity scale at the start seems sensible.
        If training diverges we can set it smaller, or to whatever gpt-neox_20B uses.
        There is no offset in RMSNorm, but I set any bias to 0 in the code anyway (WLOG).
    Activation:
        SwiGLU (not ReLU) for llama1/llama2 [used here for baby llama2]
        gpt-neox_20B doesn't say which activation it uses.
    We sample from a normal distribution because transformers without tears does, and since
    gpt-neox_20B uses nearly the same init, llama2 likely does too.

    ref: RMSNorm: https://arxiv.org/pdf/1910.07467.pdf
    ref: llama1 (llama2 uses the same arch): https://arxiv.org/pdf/2302.13971.pdf
    ref: pytorch inits: https://pytorch.org/docs/stable/nn.init.html

    ref: llama2 7b config: https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L13
    ref: follow-up discussion: https://discuss.huggingface.co/t/how-to-choose-std-for-weight-init-for-llama-2-after-reinitialize/69702

    Target (baby) llama2 model:

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 96, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=96, out_features=96, bias=False)
          (k_proj): Linear(in_features=96, out_features=96, bias=False)
          (v_proj): Linear(in_features=96, out_features=96, bias=False)
          (o_proj): Linear(in_features=96, out_features=96, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=96, out_features=11008, bias=False)
          (up_proj): Linear(in_features=96, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=96, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=96, out_features=32000, bias=False)
)
    This model comes from:
        get_smaller_llama2(hidden_size=32*3, num_hidden_layers=32, verbose=verbose)
    so in_features = 96 ==> D = 96 (a quick sanity check of the resulting std is sketched after the function).
    """
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            # All linear layers (attention projections, MLP, lm_head): std = 3 / (L * sqrt(D)), see docstring
            D = module.in_features  # I think this is right since the layer computes xW
            L = module.weight.shape[1]
            nn.init.normal_(module.weight, mean=0, std=3 / (L * D**0.5))
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif 'norm' in name.lower() or 'norm' in str(module).lower():
            # Crude name/repr-based check: matches LlamaRMSNorm, but also parent containers
            # whose repr mentions a norm; those have no weight/bias, so they are skipped harmlessly.
            if getattr(module, 'weight', None) is not None:
                nn.init.constant_(module.weight, 1)  # gain g_i = 1
            if getattr(module, 'bias', None) is not None:
                nn.init.constant_(module.bias, 0)  # RMSNorm has no offset, but zero it if present
        elif getattr(module, 'weight', None) is not None:
            # Other weight-bearing layers (e.g. the embedding): std = sqrt(2 / (d + 4d))
            # (note: for the embedding, shape[0] is the vocab size, not the hidden dim)
            D = module.weight.shape[0]
            nn.init.normal_(module.weight, mean=0, std=(2 / (D + 4 * D))**0.5)
            if getattr(module, 'bias', None) is not None:
                nn.init.constant_(module.bias, 0)
        # containers, activations, rotary embeddings, etc. have no weights to re-init
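To answer the "[where is eps?]" note in the docstring: implementations add eps to the mean of squares inside the sqrt. Here is a minimal sketch (assuming plain PyTorch; it mirrors what HF's LlamaRMSNorm does, minus dtype handling). The weight parameter is the gain g_i that the re-init above sets to 1:

import torch
import torch.nn as nn

class SimpleRMSNorm(nn.Module):
    """Minimal RMSNorm sketch: a_bar_i = g_i * a_i / sqrt(mean_j(a_j^2) + eps)."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # g_i, the gain re-initialized to 1 above

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)                 # 1/n * sum_j a_j^2
        return self.weight * x * torch.rsqrt(variance + self.eps)  # eps sits inside the sqrt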
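And a rough sanity check of the re-init on the small config printed in the docstring. This is a sketch, assuming HF transformers' LlamaConfig / LlamaForCausalLM; num_attention_heads=8 is a guess, since the printout above doesn't show it (any value dividing hidden_size works):

from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(
    vocab_size=32000,
    hidden_size=96,            # 32 * 3, as in get_smaller_llama2 above
    intermediate_size=11008,
    num_hidden_layers=32,
    num_attention_heads=8,     # assumption: not shown in the printout
    pad_token_id=0,
)
model = LlamaForCausalLM(config)
reinitialize_weights_gpt_neox_20B_inspired_4_llama2(model)

# For a linear layer the target std is 3 / (L * sqrt(D)), with D = in_features and L = weight.shape[1]
w = model.model.layers[0].self_attn.q_proj.weight
print(f"empirical std: {w.std().item():.5f}, target std: {3 / (w.shape[1] * w.shape[1] ** 0.5):.5f}")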