FasterDecoding / SnapKV


It seems that SnapKV needs to be able to do "prefill" at least once before the prompt can be compressed. #9

Closed. 66RING closed this issue 2 months ago.

66RING commented 2 months ago

SnapKV needs a full-length q–k matmul before its first self-attention, which has $O(n^2)$ space complexity. So does SnapKV need to be able to do "prefill" at least once before the prompt can be compressed?

After that, it can reduce the memory footprint during the decoding phase.

    def update_kv(self, key_states, query_states, value_states, attention_mask, num_key_value_groups):
        # check if prefill (prefix) phase: q and k cover the same positions
        assert key_states.shape[-2] == query_states.shape[-2]
        bsz, num_heads, q_len, head_dim = query_states.shape
        if q_len < self.max_capacity_prompt:
            # prompt already fits in the cache budget; keep everything
            return key_states, value_states
        else:
            # scores of the last `window_size` queries against all keys:
            # shape (bsz, num_heads, window_size, q_len)
            attn_weights = torch.matmul(query_states[..., -self.window_size:, :], key_states.transpose(2, 3)) / math.sqrt(head_dim)
            # ...
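
For context, here is a rough sketch of how the compression could continue after this point, based on the paper's description of aggregating the observation window's attention and keeping the top-scoring prefix positions. This is not the repo's exact code; the identifiers (`self.window_size`, `self.max_capacity_prompt`) follow the excerpt above:

            # Sketch only (not the repo's exact code): softmax the window
            # scores, total the attention each prefix position receives from
            # the window, and keep the top (max_capacity_prompt - window_size).
            attn_weights = torch.softmax(attn_weights, dim=-1)
            scores = attn_weights[..., :-self.window_size].sum(dim=-2)  # (bsz, heads, prefix_len)
            indices = scores.topk(self.max_capacity_prompt - self.window_size, dim=-1).indices
            indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            k_past = key_states[..., :-self.window_size, :].gather(dim=2, index=indices)
            v_past = value_states[..., :-self.window_size, :].gather(dim=2, index=indices)
            # compressed cache = selected prefix entries + the local window itself
            key_states = torch.cat([k_past, key_states[..., -self.window_size:, :]], dim=2)
            value_states = torch.cat([v_past, value_states[..., -self.window_size:, :]], dim=2)
            return key_states, value_states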
66RING commented 2 months ago

My bad, it is not the full length: the space complexity is window_size × full_len, i.e. $O(w \cdot n)$ rather than $O(n^2)$, in this case.
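
A minimal shape check (with made-up sizes, using the same slicing as the excerpt above) illustrates this: the prefill score matrix is window_size × full_len, not full_len × full_len.

    import math
    import torch

    # Hypothetical sizes, for illustration only.
    bsz, num_heads, head_dim = 1, 32, 128
    full_len, window_size = 4096, 32

    query_states = torch.randn(bsz, num_heads, full_len, head_dim)
    key_states = torch.randn(bsz, num_heads, full_len, head_dim)

    # Only the last `window_size` queries attend to all keys, so the score
    # tensor is (bsz, num_heads, window_size, full_len), not (full_len, full_len).
    attn_weights = torch.matmul(
        query_states[..., -window_size:, :], key_states.transpose(2, 3)
    ) / math.sqrt(head_dim)
    print(attn_weights.shape)  # torch.Size([1, 32, 32, 4096])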