Closed: awni closed this 3 months ago
looks great!
Back to draft for a few. The rotating buffer doesn't play well with the step prefill for long prompts, so that needs some work.
This is awesome!
Great work @awni 🔥
Ok so I think this can be reviewed and merged. A little note on the "infinite KV cache":
For simplicity it separates the cache growth into two stages: prefill (i.e. prompt processing) and generation. During generation we assume the updates to the cache are one time-step at a time. During the prefill stage, they can be any number.
The invariant is that every new token attends to at least `max_size - 1` previous tokens (including the first `keep=4` tokens).
During prefill, to make this happen the KV cache can grow as big as `max_size + step_size - 1` (where `step_size` is the prefill update step size). To keep things simple we don't use a circular buffer during this stage, as the masking can get a bit complicated and the code is not so well set up for that. Instead, the prefill stage simply grows by trimming the old cache and concatenating the update as a suffix to create the new cache.
During generation it uses a circular buffer with a fixed size of `max_size` and maintains an index into the next slot to write into the buffer.
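A matching sketch of the generation-time write (also illustrative; it skips the bookkeeping for the transition from a prefill cache that may have grown past `max_size`):

```python
import mlx.core as mx

def generation_update(cache, update, idx, keep=4, max_size=512):
    """Write a single-token update into a fixed-size circular buffer,
    never overwriting the first `keep` sink tokens.

    Returns the updated cache and the next write index."""
    if cache.shape[2] < max_size:
        # Still growing: plain concatenation until the buffer is full.
        cache = mx.concatenate([cache, update], axis=2)
        return cache, cache.shape[2]
    if idx >= max_size:
        # Wrap around, but skip over the sink tokens at the front.
        idx = keep
    cache[..., idx : idx + 1, :] = update
    return cache, idx + 1
```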
Thanks for the great work, I have a quick question: will the circular buffer maintain the logical order of the cache? Please correct me if I'm wrong, but it seems in the code we are not maintaining the logical order of the cache. For example, if we start with an initial cache of [1, 2, 3, 4, 5, 6] and keep 1 and 2 as the attention sink, then during generation the cache becomes [1, 2, 7, 4, 5, 6]. Self-attention is using [1, 2, 7, 4, 5, 6] instead of [1, 2, 4, 5, 6, 7].
Your understanding is exactly right. The logical order doesn't matter in this case; the output is the same since self-attention is invariant to permutations of its input. (Note that RoPE is applied before the keys/values get put into the cache, so the position encodings are still valid.)
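A quick toy check of the permutation point (not the model code): with a single query and no mask over the cache, permuting the keys and values together leaves the attention output unchanged.

```python
import mlx.core as mx

q = mx.random.normal((1, 8))           # one new query
K = mx.random.normal((6, 8))           # six cached keys
V = mx.random.normal((6, 8))           # six cached values
perm = mx.array([0, 1, 5, 2, 3, 4])    # e.g. [1, 2, 7, 4, 5, 6] vs logical order

def attend(q, K, V):
    scores = mx.softmax((q @ K.T) / 8**0.5, axis=-1)
    return scores @ V

print(mx.allclose(attend(q, K, V), attend(q, K[perm], V[perm])))  # True
```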
Thanks for the detailed explanation, really appreciate it. It makes a lot of sense to me now ❤️
This PR introduces two optimizations to handle much longer prompts and generations:
Longer prompts
Add a `state` method to the KV cache classes so we can `mx.eval` only the cache state (see the sketch below).

With stepped prefill:
Without stepped prefill:
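Roughly, stepped prefill with the new `state` hook looks like this (a sketch: the model call signature and the `state` attribute access are assumptions, not copied from this PR):

```python
import mlx.core as mx

def stepped_prefill(model, cache, prompt, step_size=512):
    """Process the prompt in fixed-size chunks, evaluating only the cache
    state after each chunk so the graph never spans the whole prompt."""
    for i in range(0, prompt.shape[0], step_size):
        chunk = prompt[i : i + step_size][None]   # add a batch dimension
        model(chunk, cache=cache)                 # writes keys/values into the cache
        mx.eval([c.state for c in cache])         # materialize just the cache state
    return cache
```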
Longer generations
Allow a rotating KV cache to enable infinite generations. Toggle with the flag `--max-kv-size=512` (or whatever value). This is similar (probably identical) to the behavior in llama.cpp. The first `keep` tokens (default 4) of the prompt are always kept; everything else gets overwritten in the circular buffer. Based on this paper. With this technique you can generate indefinitely at fixed memory use.