smilelite opened 4 months ago
Hi! Does it work for Llama2-7B and Llama3-8B in your environment?
When I use `replace_with_chunkllama` from `chunkllama_attn_replace.py` with Llama2-7B, Llama3-8B, Llama2-70B, and Llama3-70B, it runs normally. However, when I use `replace_with_chunkllama` from `flash_decoding_chunkllama.py`, it doesn't work.
Llama2-7B is OK. The error is caused by GQA (grouped-query attention): the head dim computed in `flash_decoding_chunkllama.py` is inconsistent with the head dim used in `modeling_llama.py`.
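For reference, a minimal sketch of the arithmetic behind the mismatch (config values are from the public Llama2-70B config; the "wrong derivation" is my guess at how the 1024 arises, not a line copied from the repo):

```python
# Llama2-70B attention geometry (values from the public HF config),
# used only to illustrate where 1024 vs. 128 comes from.
hidden_size = 8192
num_attention_heads = 64     # query heads
num_key_value_heads = 8      # GQA: 8 KV heads shared across 64 query heads

head_dim = hidden_size // num_attention_heads      # 128, what modeling_llama.py uses
wrong_dim = hidden_size // num_key_value_heads     # 1024, one plausible wrong derivation

# The KV cache is laid out as [batch, max_seq_len, kv_heads, head_dim].
# If head_dim is derived from the KV head count instead of the query head
# count, the cache is allocated with last dim 1024 while key_states arrive
# with 128, matching the RuntimeError reported below.
print(head_dim, wrong_dim)   # 128 1024
```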
Thank you so much for letting me know! I will update the code to support GQA😊
So sorry for the late response; I've been busy the past two weeks. The code now works well with Llama3. Please try it!
I used flash decoding in `run_chunkllama_100k.py` by swapping the import:

```python
# from chunkllama_attn_replace import replace_with_chunkllama
from flash_decoding_chunkllama import replace_with_chunkllama
```
The model is Llama2-70B or Llama3-70B, with `transformers==4.40.1` and `torch==2.2.1`.
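For context, the patch is applied before the model is loaded, roughly like this (a minimal sketch; the `pretraining_length` argument follows the repo README, and the model path is a placeholder):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from flash_decoding_chunkllama import replace_with_chunkllama

# Monkey-patch Llama attention before the model is instantiated
replace_with_chunkllama(pretraining_length=4096)

model_path = "meta-llama/Llama-2-70b-hf"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype="auto",
    attn_implementation="flash_attention_2",  # the flash decoding path expects FA2
    device_map="auto",
)
```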
The error is:

```
File "/mnt/ChunkLlama/flash_decoding_chunkllama.py", line 327, in forward
    key_cache[:, kv_seq_len - key_states.shape[-2]:kv_seq_len, :, :] = key_states.transpose(1, 2)
RuntimeError: The expanded size of the tensor (1024) must match the existing size (128) at non-singleton dimension 3. Target sizes: [1, 26141, 8, 1024]. Tensor sizes: [26141, 8, 128]
```
By the way, I added `**kwargs` to the `LlamaModel_forward` signature and replaced `if self._attn_implementation:` with `if self.config._attn_implementation == "flash_attention_2":` (sketched below).
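Those two tweaks would look roughly like this (a sketch only; the argument names follow the standard `LlamaModel.forward` signature in transformers 4.40, and the elided bodies are unchanged):

```python
def LlamaModel_forward(self, input_ids=None, attention_mask=None,
                       position_ids=None, past_key_values=None,
                       inputs_embeds=None, use_cache=None,
                       output_attentions=None, output_hidden_states=None,
                       return_dict=None,
                       **kwargs):  # added: swallow extra kwargs passed by newer transformers
    ...
    # replaced "if self._attn_implementation:" with an explicit check
    # of the configured attention backend:
    if self.config._attn_implementation == "flash_attention_2":
        ...  # flash-attention branch, body unchanged
    ...
```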