datamllab / LongLM

[ICML'24 Spotlight] LLM Maybe LongLM: Self-Extend LLM Context Window Without Tuning
https://arxiv.org/pdf/2401.01325.pdf
MIT License

Run example.py Error: Failed to modify the attention method of LlamaForCausalLM #39

Closed tuzeao-tal closed 1 month ago

tuzeao-tal commented 1 month ago

Hello. I simply ran example.py and hit the following error in the "=====SelfExtend using Torch======" part:

Traceback (most recent call last):
  File "./LongLM/example.py", line 112, in <module>
    SelfExtend.apply(model, group_size, window_size, enable_flash_attention=False)
  File "./LongLM/SelfExtend.py", line 123, in apply
    raise Exception(f"Failed to modify the attention method of {arch_name}")
Exception: Failed to modify the attention method of LlamaForCausalLM

transformers=4.41, flash_attn=2.5.8

Meanwhile, I noticed the similar issue https://github.com/datamllab/LongLM/issues/31, so I also tried disabling flash attention when loading the model:

model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16, use_flash_attention_2=False)

SelfExtend.apply(model, group_size, window_size, enable_flash_attention=False)

I printed the model, and the printout shows it is not using flash attention:

  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
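
(In case it matters: the LlamaSdpaAttention above appears to come from the default attention backend rather than from flash attention. For reference, this is how I would pin the implementation explicitly; a sketch assuming the attn_implementation argument of recent transformers, and I have not verified that it helps with SelfExtend:)

import torch
from transformers import AutoModelForCausalLM

# Request the eager attention classes at load time, so the decoder layers use
# LlamaAttention instead of the default LlamaSdpaAttention.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)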

So where is the problem?

Mooler0410 commented 1 month ago

This problem probably comes from the fact that in transformers==4.41 the default attention module for Llama models changed from LlamaAttention / LlamaFlashAttention2 to LlamaSdpaAttention, so the forward-function modification fails. You may modify these lines: line1, line2 to fix this. However, we are not sure whether the other parts work well with transformers==4.41, so you had better use transformers==4.38.2 or transformers==4.40.
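
Roughly, the kind of change meant here is extending the set of attention class names whose forward gets replaced. The sketch below is an illustration, not the exact code in SelfExtend.py, and self_extend_forward stands for whatever modified forward SelfExtend provides:

import types

# Attention classes whose forward should be swapped. In transformers==4.41 the
# default for Llama is LlamaSdpaAttention, so it has to be matched in addition
# to the older class names.
TARGET_ATTN_CLASSES = {"LlamaAttention", "LlamaFlashAttention2", "LlamaSdpaAttention"}

def patch_llama_attention(model, self_extend_forward):
    """Bind the modified forward to every matching attention module."""
    patched = 0
    for module in model.modules():
        if module.__class__.__name__ in TARGET_ATTN_CLASSES:
            module.forward = types.MethodType(self_extend_forward, module)
            patched += 1
    if patched == 0:
        raise Exception(f"Failed to modify the attention method of {model.__class__.__name__}")
    return patched

With LlamaSdpaAttention covered, the apply step should at least find the attention modules again; whether the modified forward itself behaves correctly under 4.41 is a separate question, which is why pinning the transformers version is the safer option.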

tuzeao-tal commented 1 month ago

Sounds good, thanks. I will try it again later.