FasterDecoding / SnapKV


Questions on paper and code [prompting for mistral, positional index, minor errors & questions in paper] #1

Open MarsJacobs opened 2 months ago

MarsJacobs commented 2 months ago

Hello :) Thank you for the excellent work and for sharing your code. I've learned a lot and have a few questions about the paper and settings:

Additionally, there seems to be a minor error in Figure 7 where both the top and bottom plots are labeled as "without Pooling." It might be less confusing to label the bottom plot as "with Pooling."

Thank you for any insights you can provide. I really appreciate the motivation and methodology behind your work!

leeyeehoo commented 2 months ago

Hi Minsoo,

Thank you for your suggestion! I am working on polishing this repository. For the first two questions, I am preparing support code so y'all can rerun our experiments (give me roughly 1-2 days). For the third question, as we noted in the last section of the paper, SnapKV has its own drawback: it cannot compress the KV cache during the prompt phase. This is a very challenging issue, and we acknowledge the limitation. Such query-dependent compression can also be less effective when the prompt is extremely long (say, 1M tokens). We are running experiments to see whether there are better ways to reduce prompt-phase computation. That is also why you encounter the OOM: the current code uses the naive attention implementation, which needs O(n^2) memory for the prompt.
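For intuition, here is a back-of-the-envelope sketch of why the prompt phase blows up with naive attention (the head count and dtype below are illustrative assumptions, not measurements from our code):

```python
# Rough estimate of naive attention-score memory for the prompt phase.
# Assumptions (illustrative only): fp16 scores, 32 heads; real kernels differ.
def attn_score_bytes(seq_len: int, num_heads: int = 32, bytes_per_elem: int = 2) -> int:
    # One (seq_len x seq_len) score matrix per head.
    return seq_len * seq_len * num_heads * bytes_per_elem

for n in (4_096, 32_768, 300_000):
    print(f"{n:>7} tokens -> {attn_score_bytes(n) / 1e9:,.1f} GB of attention scores")
```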

Once I commit and polish the repo I will notify you :)

guozhiyu commented 2 months ago

Integration with FlashAttention may be important for extremely long contexts. There are several issues about this in the H2O GitHub repository; unfortunately, it hasn't been solved yet.

leeyeehoo commented 2 months ago

We do support FlashAttention. For computing the top-k over the attention distribution, we don't have a fused implementation yet; when we benchmarked with 300k-token sequences, this part showed no significant overhead.
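For reference, a minimal, unfused sketch of the kind of selection meant here (tensor names, shapes, and hyperparameters are illustrative, not the repo's actual code):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: attention of the last `window` observation tokens
# over the whole prompt, per head.
batch, heads, window, prompt_len = 1, 8, 32, 4096
capacity, kernel = 1024, 5
attn = torch.rand(batch, heads, window, prompt_len).softmax(dim=-1)

# Sum the observation window's attention for each prompt position,
# smooth with 1D pooling, then keep the top-`capacity` positions per head.
votes = attn.sum(dim=-2)                                     # [batch, heads, prompt_len]
votes = F.avg_pool1d(votes, kernel_size=kernel, padding=kernel // 2, stride=1)
topk_indices = votes.topk(capacity, dim=-1).indices          # positions to keep
```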

MarsJacobs commented 2 months ago

Hello!

I've been looking into your code implementation and have a question regarding the positional index. It seems that tokens in the input context are evicted with their positional indices already encoded (similar to what's done in H2O).

For example, if we have tokens T1, T2, T3, and T4 with positional indices 0, 1, 2, and 3 respectively, and T3 gets evicted, the remaining T1, T2, and T4 would have positional indices of 0, 1, and 3.

However, if we apply positional encoding after eviction, then the indices for T1, T2, and T4 would be adjusted to 0, 1, and 2, similar to the rolling KV cache mechanism used in streaming LLMs.
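To make the two schemes concrete, here is a tiny sketch of what I mean (hypothetical indices only, not your code):

```python
import torch

# Hypothetical example: tokens T1..T4 originally at positions 0..3; T3 is evicted.
keep = torch.tensor([0, 1, 3])                # indices of the retained tokens

# (a) Eviction after positional encoding (H2O-style):
#     retained tokens keep their original positions.
post_rope_positions = keep.clone()            # tensor([0, 1, 3])

# (b) Eviction before positional encoding (rolling-cache style, as in streaming LLMs):
#     retained tokens are re-indexed contiguously.
pre_rope_positions = torch.arange(len(keep))  # tensor([0, 1, 2])
```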

Do you think there would be any performance benefits to adopting this approach in the SnapKV system? Have you conducted any experiments related to this? I'm curious to hear your thoughts.

Also, I wanted to mention that I've read your reply and am looking forward to the commit. Thanks as always for your prompt and helpful responses 😄

leeyeehoo commented 2 months ago

I got you. You can definitely try this (we don't have an experiment or implementation for pre-RoPE eviction).

From my VERY personal view, context lengths for new, incoming LLMs will grow significantly. As you can see, Llama-3 can obtain long-context abilities with some light fine-tuning (https://huggingface.co/gradientai), so using such strategies to extend the context length could become trivial. What (I think) we need to address is how to efficiently serve long contexts in a near-lossless way...

Another thing worth noting: some evidence also suggests that when serving the model, the KV cache is dropped after each round of chat. If you try Claude-3 Opus for long-document QA, it is very slow regardless of whether you have a previous prompt, and it limits your trials (something like 10 per 5 hours, I guess). So we are exploring how to compress the whole prompt (which is missing in this version); otherwise, when the context length exceeds 100k, the overhead comes from the prompt phase rather than generation.

leeyeehoo commented 2 months ago

[prompting for mistral]

The reason is that we asked the author of LLM Maybe LongLM. He told us there are performance discrepancies for Mistral-Ins-V0.1 and suggested we remove the instruction template. You should be able to reproduce our results in experiments/LongBench. You can also put the instruction template back (it is commented out). If you find inconsistencies, feel free to check with me.

MarsJacobs commented 2 months ago

Thank you for your valuable insights and detailed explanation about positional embeddings! I tried running the newly committed code with mistral-v0.2-32k and achieved performance nearly identical to the paper's reported results on Qasper with Transformers==4.38.2:

SnapKV-4096 reproduced: 33.38; reported in paper: 33.36

I do have one question regarding your mention of "compress the whole prompt." Could you provide more details on what this involves? As I understand it, SnapKV already applies compression across the entire input context. I'm curious about what distinguishes the point you mentioned from the existing method.

Also, I have a minor question regarding Figures 2, 3, and 4: Is the y-axis supposed to represent percentages or ratios? The axis is labeled with a '%' but the scale goes up to 1.0, which seems like a small error. Should I interpret it as 100%?

Additionally, in Table 4, the GT positions are listed as 0, 14, 30 regardless of the number of documents, which might also need some adjustment. I hope this feedback is helpful as you prepare the revised manuscript!

leeyeehoo commented 2 months ago

I see the typo in Table 4. I will add the observation experiment next week, since I am preparing for my defense this weekend :)