Zefan-Cai / PyramidKV

The Official Implementation of PyramidKV: Dynamic KV Cache Compression based on Pyramidal Information Funneling
https://arxiv.org/pdf/2406.02069
MIT License

A serious issue in your code #15

Open JulietLJY opened 3 months ago

JulietLJY commented 3 months ago

There seems to be a serious issue in run_longbench.py. update_kv is only called for the first sample in LongBench, so the statement print(f"PyramidKV max_capacity_prompt {max_capacity_prompt}") only prints during the first sample. This implies that only the first sample uses PyramidKV. Below are the key_states.shape values printed for the first sample:

torch.Size([1, 32, 7955, 128])
torch.Size([1, 32, 7834, 128])
...
torch.Size([1, 32, 4204, 128])

Starting from the second sample, update_kv is no longer executed, and the key_states.shape is as follows:

torch.Size([1, 32, 7957, 128])
torch.Size([1, 32, 7957, 128])
...
torch.Size([1, 32, 7957, 128])

This implies that, starting from the second sample, all attention is full attention and there is no compression of the KV cache at all. Looking forward to your response.
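For context, SnapKV-style implementations such as this one typically gate the compression call on a per-layer sequence-length counter that should be zero at the start of each sample. The toy sketch below is purely illustrative (the class, its forward, and the simplified update_kv are hypothetical stand-ins, not PyramidKV's actual code); it reproduces the symptom above when that counter is never reset between samples:

```python
import torch

class DemoAttention:
    """Minimal stand-in for an attention layer with SnapKV/PyramidKV-style gating.

    `kv_seq_len` plays the role of the per-layer counter discussed in this issue;
    `update_kv` stands in for the real compression call. Names and logic are
    illustrative only, not the repository's implementation.
    """

    def __init__(self, max_capacity_prompt=256):
        self.kv_seq_len = 0
        self.max_capacity_prompt = max_capacity_prompt

    def update_kv(self, key_states):
        # Stand-in for PyramidKV compression: keep only the last
        # `max_capacity_prompt` positions of the prompt KV cache.
        return key_states[:, :, -self.max_capacity_prompt:, :]

    def forward(self, key_states):
        bsz, num_heads, q_len, head_dim = key_states.shape
        if self.kv_seq_len == 0 and q_len > 1:
            # Prefill of a new sample: compress the prompt KV cache.
            self.kv_seq_len = q_len
            key_states = self.update_kv(key_states)
            print("compressed:  ", key_states.shape)
        else:
            # Decode step (or a prefill mistaken for one): no compression.
            self.kv_seq_len += q_len
            print("uncompressed:", key_states.shape)
        return key_states

attn = DemoAttention()
attn.forward(torch.randn(1, 32, 7955, 128))  # sample 1 prefill -> compressed
# Without resetting attn.kv_seq_len between samples, the next prompt is
# treated like a decode step and stays uncompressed, matching the shapes above.
attn.forward(torch.randn(1, 32, 7957, 128))  # sample 2 prefill -> uncompressed
```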

Zefan-Cai commented 3 months ago

Thank you so much for pointing this out!

This is possibly a bug related to the Transformers version. I recall encountering this issue before, but I don't remember exactly which Transformers version it was.

Would you mind changing the Transformers version to 4.41? This is expected to solve the issue. If it works out, key_states.shape should look like the following for all samples:

debug key_states.shape torch.Size([1, 32, 243, 128])
debug key_states.shape torch.Size([1, 32, 236, 128])
debug key_states.shape torch.Size([1, 32, 26, 128])
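As a quick sanity check before re-running LongBench, one can confirm which Transformers version is actually active in the environment (the 4.41 pin below just reflects the suggestion above; assuming any 4.41.x patch release behaves the same):

```python
import transformers

# PyramidKV's monkey patching is sensitive to the Transformers release;
# this thread suggests 4.41.x.
print(transformers.__version__)
assert transformers.__version__.startswith("4.41"), (
    f"expected transformers 4.41.x, got {transformers.__version__}"
)
```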

If you are already using Transformers 4.41, please let me know, and I will try to figure it out.

JulietLJY commented 3 months ago

Thank you for your timely response. After updating the transformers version to 4.41, the issue has been resolved. I suggest you consider advising users about this issue, as it doesn’t produce any error messages and could easily lead to confusing results.

Zefan-Cai commented 3 months ago

Thank you so much for pointing this out.

We will add a note about this as you suggest. This is indeed confusing.

liuxiaozhu01 commented 3 months ago

I have a question about self.kv_seq_len here. I'm confused about where it is initialized and reset to 0 when a new batch comes in. If it is not reset to 0 for a new batch, it seems that update_kv won't be called, just as @JulietLJY mentioned earlier. Is this related to the Transformers version?

Zefan-Cai commented 3 months ago

> I have a question about self.kv_seq_len here. I'm confused about where it is initialized and reset to 0 when a new batch comes in. If it is not reset to 0 for a new batch, it seems that update_kv won't be called, just as @JulietLJY mentioned earlier. Is this related to the Transformers version?

It is initialized at https://github.com/Zefan-Cai/PyramidKV/blob/73c08b1dc1104b2d614c0670478d297a7a4df8c1/pyramidkv/llama_model.py#L1382. This function replaces the original preparation function via monkey patching and takes care of the kv_seq_len initialization. Without this function, you would end up in the situation you described, so this is probably not caused by the Transformers version.
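For anyone reading along, the rough mechanism is sketched below. This is a simplified illustration under the assumption that each patched attention layer keeps a self.kv_seq_len counter, as discussed in this thread; it is not the repository's actual code, which lives at the link above:

```python
from transformers.models.llama.modeling_llama import LlamaForCausalLM

# Keep a handle on the original method so the patched version can delegate to it.
_original_prepare = LlamaForCausalLM.prepare_inputs_for_generation

def prepare_inputs_for_generation_with_reset(self, input_ids, past_key_values=None, **kwargs):
    # No past cache yet means this call starts a new sample / generation:
    # reset each layer's counter so the upcoming prefill takes the compression
    # path (update_kv) again instead of the decode path.
    # (Depending on the Transformers version, a "new sample" may show up as an
    # empty Cache object rather than None; adjust the check accordingly.)
    if past_key_values is None:
        for layer in self.model.layers:
            layer.self_attn.kv_seq_len = 0  # assumes the patched attention uses this counter
    return _original_prepare(self, input_ids, past_key_values=past_key_values, **kwargs)

# Monkey patch: every LlamaForCausalLM now goes through the resetting version.
LlamaForCausalLM.prepare_inputs_for_generation = prepare_inputs_for_generation_with_reset
```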

liuxiaozhu01 commented 3 months ago

Oh! Thank you for your timely reply. That was careless of me. :)