HKUNLP / ChunkLlama

[ICML'24] Data and code for our paper "Training-Free Long-Context Scaling of Large Language Models"
Apache License 2.0

An error occurred when I used flash_decoding_chunkllama in run_chunkllama_100k.py #17

Open · smilelite opened this issue 4 months ago

smilelite commented 4 months ago

I used flash decoding in run_chunkllama_100k.py by replacing

from chunkllama_attn_replace import replace_with_chunkllama

with

from flash_decoding_chunkllama import replace_with_chunkllama

The model is Llama2-70B or Llama3-70B, with transformers==4.40.1 and torch==2.2.1.

The error is:

File "/mnt/ChunkLlama/flash_decoding_chunkllama.py", line 327, in forward
    key_cache[:, kv_seq_len - key_states.shape[-2]:kv_seq_len, :, :] = key_states.transpose(1, 2)
RuntimeError: The expanded size of the tensor (1024) must match the existing size (128) at non-singleton dimension 3. Target sizes: [1, 26141, 8, 1024]. Tensor sizes: [26141, 8, 128]
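For reference, the mismatch can be reproduced in isolation. This is a minimal sketch, with the 1024 vs. 128 head dimensions taken from the trace above and the sequence lengths shrunk so the tensors stay small:

```python
import torch

# Shapes from the trace above; sequence lengths shrunk for a quick repro.
bsz, num_kv_heads = 1, 8
wrong_head_dim = 1024   # head dim the cache was allocated with
true_head_dim = 128     # head dim of the incoming key_states
kv_seq_len, max_len = 32, 64

key_cache = torch.zeros(bsz, max_len, num_kv_heads, wrong_head_dim)
key_states = torch.zeros(bsz, num_kv_heads, kv_seq_len, true_head_dim)

# Same assignment pattern as flash_decoding_chunkllama.py, line 327:
# raises "The expanded size of the tensor (1024) must match the existing
# size (128) at non-singleton dimension 3."
key_cache[:, kv_seq_len - key_states.shape[-2]:kv_seq_len, :, :] = key_states.transpose(1, 2)
```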

By the way, I added "**kwargs," to LlamaModel_forward and replaced "if self._attn_implementation:" with "if self.config._attn_implementation == "flash_attention_2":".
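For context, here is a minimal sketch of how the patch is applied in a driver script like run_chunkllama_100k.py. The model name, the pretraining_length keyword, and the loading arguments are assumptions for illustration, not the repo's exact code:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from flash_decoding_chunkllama import replace_with_chunkllama  # flash-decoding variant of the patch

# Monkey-patch Llama attention before instantiating the model.
# pretraining_length is an assumed keyword argument for illustration.
replace_with_chunkllama(pretraining_length=4096)

model_name = "meta-llama/Llama-2-70b-hf"  # assumed model path
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # matches the flash-attention branch mentioned above
    device_map="auto",
)
```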

ChenxinAn-fdu commented 4 months ago

Hi! Does it work for Llama2-7B and Llama3-8B in your environment?

smilelite commented 4 months ago

When I use replace_with_chunkllama from chunkllama_attn_replace.py with Llama2-7B, Llama3-8B, Llama2-70B, and Llama3-70B, it runs normally. However, when I use replace_with_chunkllama from flash_decoding_chunkllama, it fails.

smilelite commented 4 months ago

Llama2-7B is OK. The error was caused by GQA: the head dim computed in flash_decoding_chunkllama.py is inconsistent with the head_dim used in modeling_llama.py.
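For anyone hitting the same error: the 1024 in the trace is consistent with the cache's head dim being derived from the number of key/value heads rather than the number of attention heads. Below is a minimal sketch of the GQA-aware shapes, using Llama2-70B's config values; it is only an illustration, not the repository's actual fix:

```python
import torch
from transformers import LlamaConfig

# Llama2-70B uses GQA: 64 query heads share 8 key/value heads.
config = LlamaConfig(hidden_size=8192, num_attention_heads=64, num_key_value_heads=8)

head_dim = config.hidden_size // config.num_attention_heads        # 8192 // 64 = 128 (correct)
wrong_head_dim = config.hidden_size // config.num_key_value_heads  # 8192 // 8 = 1024 (the mismatch above)

# A pre-allocated KV cache consistent with key_states.transpose(1, 2),
# whose last dimension is head_dim = 128:
bsz, max_cache_len = 1, 4096
key_cache = torch.zeros(bsz, max_cache_len, config.num_key_value_heads, head_dim)
print(key_cache.shape)  # torch.Size([1, 4096, 8, 128])
```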

ChenxinAn-fdu commented 4 months ago

Thank you so much for letting me know! I will update the code to support GQA😊

ChenxinAn-fdu commented 3 months ago

So sorry for the late response; I have been very busy for the past two weeks. The code now works well with Llama3. Plz try it!