princeton-nlp / LLM-Shearing

[ICLR 2024] Sheared LLaMA: Accelerating Language Model Pre-training via Structured Pruning
https://arxiv.org/abs/2310.06694
MIT License

LlamaRMSNorm() layer differs from original llama #63

Closed: suhmily closed this issue 5 months ago

suhmily commented 5 months ago

I noticed that the structure of the original llama model is:

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 5120, padding_idx=0)
    (layers): ModuleList(
      (0-39): 40 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (k_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (v_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (o_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (up_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (down_proj): Linear(in_features=13824, out_features=5120, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=5120, out_features=32000, bias=False)
)

while the structure of the converted model is:

ComposerMosaicLlamaAlibi(
  (model): LlamaModel(
    (transformer): ModuleDict(
      (wte): Embedding(32000, 5120)
      (blocks): ModuleList(
        (0-39): 40 x LlamaBlock(
          (ln_1): LlamaRMSNorm()
          (attn): LlamaAttention(
            (wq): Linear(in_features=5120, out_features=5120, bias=False)
            (wk): Linear(in_features=5120, out_features=5120, bias=False)
            (wv): Linear(in_features=5120, out_features=5120, bias=False)
            (out_proj): Linear(in_features=5120, out_features=5120, bias=False)
          )
          (ln_2): LlamaRMSNorm()
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=5120, out_features=13824, bias=False)
            (down_proj): Linear(in_features=13824, out_features=5120, bias=False)
            (up_proj): Linear(in_features=5120, out_features=13824, bias=False)
          )
        )
      )
      (ln_f): LlamaRMSNorm()
      (output): Linear(in_features=5120, out_features=32000, bias=False)
    )
  )
)

The position of the LlamaRMSNorm() layers differs from the original. Is this expected?

xiamengzhou commented 5 months ago

print(model) displays the modules in the order they are defined and initialized in the model's constructor, not the order in which they are executed. If you check the forward function, you will see that the normalization layers are applied at the same positions in both implementations.
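
To make the point concrete, here is a minimal, self-contained sketch. It is not code from this repo: the block classes, the RMSNorm stand-in, and the Linear layers standing in for attention/MLP are made up for illustration. It shows that print(model) reflects the registration order in __init__, while the forward pass determines where the norms are actually applied; two blocks with different registration orders can compute exactly the same function.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Minimal RMSNorm, standing in for LlamaRMSNorm."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return self.weight * x * rms

class BlockA(nn.Module):
    """Norms registered after the sublayers (HF-style naming)."""
    def __init__(self, dim):
        super().__init__()
        self.attn = nn.Linear(dim, dim, bias=False)   # stand-in for self-attention
        self.mlp = nn.Linear(dim, dim, bias=False)    # stand-in for the MLP
        self.input_layernorm = RMSNorm(dim)
        self.post_attention_layernorm = RMSNorm(dim)

    def forward(self, x):
        # Pre-norm: normalize *before* each sublayer, despite the registration order.
        x = x + self.attn(self.input_layernorm(x))
        x = x + self.mlp(self.post_attention_layernorm(x))
        return x

class BlockB(nn.Module):
    """Same computation, but norms registered before the sublayers (ln_1/ln_2 naming)."""
    def __init__(self, dim):
        super().__init__()
        self.ln_1 = RMSNorm(dim)
        self.attn = nn.Linear(dim, dim, bias=False)
        self.ln_2 = RMSNorm(dim)
        self.mlp = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

torch.manual_seed(0)
a, b = BlockA(8), BlockB(8)
# Copy A's weights into B so the two blocks are numerically identical.
b.ln_1.load_state_dict(a.input_layernorm.state_dict())
b.ln_2.load_state_dict(a.post_attention_layernorm.state_dict())
b.attn.load_state_dict(a.attn.state_dict())
b.mlp.load_state_dict(a.mlp.state_dict())

x = torch.randn(2, 8)
print(a)                           # norms printed after attn/mlp
print(b)                           # norms printed before attn/mlp
print(torch.allclose(a(x), b(x)))  # True: identical forward computation
```

The same reasoning applies here: the HF model and the converted composer model print their RMSNorm modules in different places, but both apply them in the standard pre-norm positions inside forward.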