ml-explore / mlx-examples

Examples in the MLX framework
MIT License

Handle longer prompt/generation #931

Closed: awni closed this 3 months ago

awni commented 3 months ago

This PR introduces two optimizations to handle much longer prompts and generations:

Longer prompts

With stepped prefill:

Prompt: 4232 tokens, 3138.508 tokens-per-sec
Peak memory: 2.640 GB

Without stepped prefill:

Prompt: 4232 tokens, 2300.630 tokens-per-sec
Peak memory: 8.422 GB
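
One rough way to see where savings like these can come from: processing the prompt in fixed-size chunks caps the size of each attention score matrix. The sketch below is plain Python and only counts score entries per layer and head. It assumes a hypothetical step size of 512 (the step size is not stated here), and prefill_chunks and peak_score_entries are illustrative names, not mlx_lm functions. It ignores weights, the KV cache, and other activations, so its ratio will not match the measured peak memory.

```python
def prefill_chunks(num_tokens, step_size):
    """Yield (offset, length) for each prefill chunk of the prompt."""
    for offset in range(0, num_tokens, step_size):
        yield offset, min(step_size, num_tokens - offset)


def peak_score_entries(num_tokens, step_size):
    """Entries in the largest (queries x keys) score matrix built during prefill.

    A chunk starting at `offset` has `length` queries and attends to
    offset + length keys: the cached prefix plus the chunk itself.
    """
    return max(length * (offset + length)
               for offset, length in prefill_chunks(num_tokens, step_size))


prompt_len = 4232  # prompt length from the benchmark above
one_shot = peak_score_entries(prompt_len, step_size=prompt_len)
stepped = peak_score_entries(prompt_len, step_size=512)  # hypothetical step size
print(one_shot, stepped)
```

With one-shot prefill the largest score matrix is 4232 x 4232; with 512-token steps it is bounded by 512 x 4096, roughly 8.5x smaller.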

Longer generations

Allows a rotating KV cache to enable infinite generations. Enable it with the flag --max-kv-size=512 (or any other size).

This is similar (probably identical) to the behavior in llama.cpp. The first keep (default 4) tokens of the prompt are always kept; everything else gets overwritten in the circular buffer. Based on this paper.
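
A minimal sketch of which logical token positions stay in the cache under this policy, assuming a cache of max_size slots whose first keep positions are pinned. retained_positions is a hypothetical helper, not part of mlx_lm. It describes the steady state during generation and ignores the temporarily larger cache used during prefill.

```python
def retained_positions(total, max_size, keep=4):
    """Logical positions held in the cache after `total` tokens (a sketch).

    The first `keep` positions are never evicted; the remaining
    max_size - keep slots hold the most recent tokens.
    """
    if total <= max_size:
        return list(range(total))
    recent = max_size - keep
    return list(range(keep)) + list(range(total - recent, total))


print(retained_positions(10, 6, keep=2))  # [0, 1, 6, 7, 8, 9]
```

However many tokens are generated, the cache holds max_size entries, which is why memory stays fixed.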

With this technique you can generate indefinitely at fixed memory use.

Jckwind commented 3 months ago

looks great!

awni commented 3 months ago

Back to draft for a bit. The rotating buffer doesn't play well with the stepped prefill for long prompts, so that needs some work.

Blaizzy commented 3 months ago

This is awesome!

Great work @awni 🔥

awni commented 3 months ago

Ok so I think this can be reviewed and merged. A little note on the "infinite KV cache":

For simplicity, it separates cache growth into two stages: prefill (i.e. prompt processing) and generation. During generation we assume the cache is updated one time step at a time. During prefill, an update can contain any number of time steps.

The invariant is that every new token attends to at least max_size - 1 previous tokens (including the first keep=4).

During prefill, to maintain this invariant, the KV cache can grow as large as max_size + step_size - 1 (where step_size is the prefill update step size). To keep things simple, we don't use a circular buffer during this stage, since the masking gets complicated and the code is not well set up for it. Instead, each prefill update trims the old cache and concatenates the update as a suffix to create the new cache.
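
The trim-then-concatenate prefill update can be sketched on plain Python lists standing in for the per-layer key/value arrays. prefill_update is a hypothetical name, and the attention mask inside each chunk is not modeled. The trim amount is chosen so the trimmed cache holds exactly max_size - 1 entries (the keep sink entries plus the most recent ones), which is what yields the max_size + step_size - 1 bound.

```python
def prefill_update(cache, update, max_size, keep=4):
    """Trim the old cache, then append `update` as a suffix (a sketch).

    After trimming, the cache holds the first `keep` entries plus the most
    recent max_size - 1 - keep entries, so the first new token in `update`
    still sees max_size - 1 previous tokens, including the sink tokens.
    """
    trim = len(cache) - (max_size - 1)
    if trim > 0:
        # Drop the oldest non-sink entries.
        cache = cache[:keep] + cache[keep + trim:]
    return cache + update


# Example: 20-token prompt, step_size=4, max_size=8, keep=2.
cache, peak = [], 0
for start in range(0, 20, 4):
    cache = prefill_update(cache, list(range(start, start + 4)), max_size=8, keep=2)
    peak = max(peak, len(cache))
print(peak, cache)
```

Here the cache peaks at 11 entries, i.e. 8 + 4 - 1, and ends as [0, 1, 11, 12, ..., 19].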

During generation, it uses a circular buffer with a fixed size of max_size and maintains an index of the next slot to write into.
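
A minimal sketch of the generation-stage circular buffer on a plain Python list. RotatingBuffer and its fields are hypothetical names. It assumes the cache handed over from prefill has already been cut down to at most max_size entries, and that each append is one time step. When the write index reaches max_size it wraps to keep rather than 0, so the sink entries are never overwritten.

```python
class RotatingBuffer:
    """Fixed-size circular buffer for one-token-at-a-time updates (a sketch)."""

    def __init__(self, max_size, keep=4, initial=()):
        # Assumption: prefill already trimmed the cache to <= max_size entries.
        if len(initial) > max_size:
            raise ValueError("initial cache is larger than max_size")
        self.max_size = max_size
        self.keep = keep
        self.slots = list(initial)
        self.idx = len(self.slots)  # next slot to write into

    def append(self, item):
        if len(self.slots) < self.max_size:
            # Still filling up: grow the buffer.
            self.slots.append(item)
        else:
            if self.idx == self.max_size:
                # Wrap around, skipping the first `keep` sink slots.
                self.idx = self.keep
            # Overwrite the oldest non-sink entry.
            self.slots[self.idx] = item
        self.idx += 1


buf = RotatingBuffer(max_size=6, keep=2, initial=[1, 2, 3, 4, 5, 6])
buf.append(7)
print(buf.slots)  # [1, 2, 7, 4, 5, 6]
```

The physical order ([1, 2, 7, 4, 5, 6]) no longer matches the logical order, which is the situation raised in the question that follows.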

mzbac commented 3 months ago

Thanks for the great work. I have a quick question: will the circular buffer maintain the logical order of the cache? Please correct me if I'm wrong, but it seems the code does not maintain the logical order. For example, if we start with an initial cache of [1, 2, 3, 4, 5, 6] and keep 1 and 2 as attention sinks, then during generation the cache becomes [1, 2, 7, 4, 5, 6]. Self-attention then uses [1, 2, 7, 4, 5, 6] instead of [1, 2, 4, 5, 6, 7].

awni commented 3 months ago

Your understanding is exactly right. The logical order doesn't matter here: the output is the same because self-attention is invariant to permutations of the cached key/value pairs. (Note that RoPE is applied before the keys and values are put into the cache, so the position encodings are still valid.)
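
A small numerical check of this point in plain Python: one query attends over the same cached key/value pairs stored in logical order and in the rotated buffer order from the example above. attend is a hypothetical helper doing single-query softmax attention with no mask (a generation-step query sees every cached slot), and the vectors are arbitrary stand-ins for keys that already carry their RoPE rotation. The two outputs agree up to floating-point rounding.

```python
import math


def attend(query, keys, values):
    """Single-query scaled dot-product attention over lists of vectors."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    top = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]


# Arbitrary key/value vectors for logical positions [1, 2, 4, 5, 6, 7].
keys = [[0.1, 0.3], [0.5, -0.2], [0.9, 0.4], [-0.3, 0.8], [0.2, 0.2], [0.7, -0.6]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 1.0], [1.0, 3.0], [0.5, 0.5], [3.0, 2.0]]
query = [0.4, 0.6]

# Physical buffer order [1, 2, 7, 4, 5, 6]: logical index 5 sits in slot 2.
perm = [0, 1, 5, 2, 3, 4]
logical_out = attend(query, keys, values)
buffer_out = attend(query, [keys[i] for i in perm], [values[i] for i in perm])
print(logical_out, buffer_out)
```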

mzbac commented 3 months ago

Thanks for the detailed explanation, really appreciate it. It makes a lot of sense to me now ❤️