Zefan-Cai / PyramidKV

The Official Implementation of PyramidKV: Dynamic KV Cache Compression based on Pyramidal Information Funneling
https://arxiv.org/pdf/2406.02069
MIT License
480 stars 43 forks source link

对比了下StreamingLLM源码实现,貌似PyramidKV中的实现位置编码这块有问题 #19

Open free-stardust opened 1 month ago

free-stardust commented 1 month ago

你好,PyramidKV 这个实现确实很巧妙,但是有个问题:我看了下 StreamingLLM 的源码,这个实现中,由于位置编码重新排列,所以注意力实现那里在存 kv cache 之前,并没有对 k 进行位置嵌入,在存 kv cache 后才对 k 进行了临时的位置嵌入和注意力计算;但是 PyramidKV 对应的 StreamingLLM 实现,是在对 q 和 k 同时进行位置嵌入之后进行了kv cache 保存和注意力计算,结果就导致没有实现 StreamingLLM 中的 kv cache 对应位置的重新排列,省略了这块后作为对比测试,是不是有点问题?

Zefan-Cai commented 1 month ago

我推测原来的StreamingLLM为了多次编码4k长度的文本才进行如你所说的操作:在存 kv cache 后才对 k 进行了临时的位置嵌入和注意力计算。因为多次编码的话,在编码中如果加入位置编码,会导致最终前面4个token被多次添加位置编码。如果是进行longBench的测试的话,不会进行多次编码,所以和SnapKV、PyramidKV保持一样的操作应该页合理,也就是:对 q 和 k 同时进行位置嵌入之后进行kv cache 保存和注意力计算。

free-stardust commented 1 month ago

我描述的好像有点问题,我的理解是这样的:

因为 StreamingLLM 原始版本实现里面有个 pos shift,用论文中的例子说就是一个长为 8 的序列移除其中的 4 和 5 之后,各个 token 的位置 pos_id 就变成了这样 [0, 1, 2, 3, 6, 7, 8],这种情况下计算 k 的位置编码时,就按照当前 kv cache 的当前内容长度,将后面断续的位置 pos_id 重新进行了连续性替换,变成了 [0, 1, 2, 3, 4, 5, 6],然后最新生层的 k,编进去则为 7,而不是默认实现对应的位置 id=9,之后基于 [0, 1, 2, 3, 4, 5, 6, 7] 这个序列对 kv cache 进行新的的位置编码。

如果这个序列不重新配置,那么就是 transformers 库默认的实现,即默认的 kv cache 是利用 [0, 1, 2, 3, 6, 7, 8] 这个 pos_id 值进行的位置编码,然后新生成时,对于 pos_id=9 的 token,则是按照 9 进行编码。

因此默认情况下位置编码 k 时,即使 kv cache 中的 token 会被逐渐被 evict,但他们位置编码 pos_id 实际就是各种变动的不连续序列,而 StreamingLLM 对应的 pos shift 实现里面为了达到位置动态配置,保存的 k 是无位置编码的,只在注意力计算时才会按照当前的 kv cache 长度对整个 kv cache 中的 k 进行位置编码嵌入,这个如果在存 kv cache 前就进行位置编码的话,确实会重复添加位置编码,但这种种情况下,不止是前面 4 个会重复添加位置编码,后续的所有 token 都会被重复添加位置编码,但这实际也可以通过切片进行规避,而且默认的注意力实现也只是计算最新的 token 的 k 并对其进行编码存入 cache,所以我理解的是,他这么做的目的,就是为了对 evict 之后的 kv cache 的各个 token 相对位置重新配置,所以进行了这个操作。

但如果不进行这个重新配置的话,就是采用默认的注意力实现,位置 pos_id 不会重新配置,就是按照正常的 pos_id 进行对 k 进行位置编码,其实确实也可以,毕竟 StreamingLLM 把这个 pos shift 作为了可配置项。

然后纠结这个是因为我自己测试的时候,发现他这个 trick 效果确实会好一点,而且他论文里面也对这个进行了详细叙述,但是看 PyramidKV 的实现里面没有这个部分,有点疑惑。

Zefan-Cai commented 1 month ago

我确实是没注意到这个trick。这个问题应该分为两个方面:

1.保存的kv cache不添加位置编码,应该确实是因为会引入重复添加位置编码的问题。如果每次保存都添加位置编码,前四个token和local window都会被多次添加位置编码;而且也不方便在保存后在下一次encode时输入pos ⇧以后得到的新的位置编码。这个操作和其他的kv cache compression工作是不同的。其他的单次编码的工作可以在保存以前就给kv cache添加位置编码,不会影响结果。

  1. 位置编码偏移可能确实是一个trick,我之前没有注意。其他的工作一般没有使用改动的位置编码。直觉上讲,不改动位置编码效果应该会好一些,因为这样模型能对每个token的位置有正确的感知。StreamingLLM使用这个trick的原因,我推测是它需要用4k context length的LLM编码4k长度的内容多次,如果使用没有pos shift的话可能会使用比4k大得多的位置编码。这样的话,LLM会接触到一个没有在训练时见过的位置编码,因此对效果有负面影响。
free-stardust commented 1 month ago

确实,如果按照 StreamingLLM 这种实现,确实会存在重复位置编码的情况,不过从他这个代码逻辑上看,他这种实现应该就是为了服务他的 trick,相比之下其他的 kv cache 压缩是在默认的增量更新基础上进行的修改,所以不会影响结果。

然后关于第二方面,这个我也感觉很奇怪,按 RoPE 的原理来说,应该是保存相对位置更好,并且这个理论上是具有更好的外推性,但是在 StreamingLLM 重新编码位置后,相对位置信息改变了,生成效果居然好了,这确实有点神奇,我第一反应是位置编码算法特性的问题,您说的这个推测,我还真没想到,学习了,感谢感谢!