ml-explore / mlx-examples


PR: Add KV-cache creation capability to mlx_lm.generate for after a text completion #1001

Closed. mark-lord closed this pull request 1 month ago.

mark-lord commented 2 months ago

Whilst mlx_lm.cache_prompt lets you encode a prompt and save the key-value pairs as a kv-cache.safetensors file in advance, there's currently no way to save the KV cache after a text completion by an LLM.

Adding this will bring mlx_lm more in line with Llama.cpp in terms of reducing latency in multi-turn scenarios, for example using an MLX-served model as a chatbot and having a drawn-out discussion about a given topic. Saving the KV cache after each of the LLM's turns means that even as the conversation history grows, there won't be any latency introduced by having to re-encode the entire chat log again; only the most recent user prompt needs encoding.

Not sure I went about it the best way, but it seems to work from my testing! There's one superfluous edit to the step generator at line 357 which can probably be left out, but otherwise I think I kept this as streamlined as I could.

mark-lord commented 2 months ago

Oop, forgot to add an explanation of how to use this. Copied and pasted from my explanation in the MLX Discord:

On first-time initialisation of a chat you'll need to create the cache for the model to use, which can be done with the normal cache_prompt.py:

mlx_lm.cache_prompt --model 'mlx-community/Llama-3.2-1B-Instruct-4bit' --save-kv-cache "rollingcache.safetensors" --prompt "From now on, talk like a pirate"

And then afterwards you'd do something like:

mlx_lm.generate --model 'mlx-community/Llama-3.2-1B-Instruct-4bit' --save-kv-cache "rollingcache.safetensors" --kv-cache-file "rollingcache.safetensors" --prompt "Tell me a joke"

Which creates the generation: Arrr, settle yerself down with a pint o' grog and listen close, me hearty. Here be a joke fer ye: Why did the pirate quit his job? (pause for dramatic effect) Because he was sick o' all the arrrr-guments! (wink) Savvy?

Seems to work perfectly fine on a rolling basis. For example, you can then follow up with:

mlx_lm.generate --model 'mlx-community/Llama-3.2-1B-Instruct-4bit' --save-kv-cache "rollingcache.safetensors" --kv-cache-file "rollingcache.safetensors" --prompt "Explain that joke to me"

To which you'll get the reply, Alright then, matey, settle yerself down with a pint o' grog and listen close. Ye want to know the joke, eh? Alright then, I'll give it to ye. So, I be tellin' ye this one: Why'd the pirate quit his job? (pauses for dramatic effect) He be sick o' all the arrrr-guments, savvy? Arrrr, get it? It's a pirate pun, matey!

By having --save-kv-cache and --kv-cache-file point at the same file, it'll load the KV cache, take your new prompt and generate from it, then overwrite the original cache file, thereby making a rolling KV cache. Even as the conversation history grows, time to first token stays low because there's no need to re-encode the earlier turns.
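
For anyone who wants to script that loop, here's a minimal sketch of the same rolling-cache workflow driven from Python. It just shells out to the mlx_lm.cache_prompt and mlx_lm.generate entry points with the --save-kv-cache / --kv-cache-file flags proposed in this PR (so it won't work against upstream mlx_lm as-is), and the model and file names are only placeholders:

```python
import subprocess

# Sketch of the rolling-cache workflow from this PR, driven from Python.
# The --save-kv-cache / --kv-cache-file flags are the ones added by this PR
# and are not part of upstream mlx_lm.
MODEL = "mlx-community/Llama-3.2-1B-Instruct-4bit"
CACHE = "rollingcache.safetensors"  # same file is read and overwritten each turn

def cache_system_prompt(prompt: str) -> None:
    # One-time initialisation: encode the opening prompt and write the cache file.
    subprocess.run(
        ["mlx_lm.cache_prompt", "--model", MODEL,
         "--save-kv-cache", CACHE, "--prompt", prompt],
        check=True,
    )

def chat_turn(prompt: str) -> None:
    # Each turn loads the cache, generates, and overwrites the same file,
    # so only the newest user prompt needs to be encoded.
    subprocess.run(
        ["mlx_lm.generate", "--model", MODEL,
         "--save-kv-cache", CACHE, "--kv-cache-file", CACHE,
         "--prompt", prompt],
        check=True,
    )

if __name__ == "__main__":
    cache_system_prompt("From now on, talk like a pirate")
    chat_turn("Tell me a joke")
    chat_turn("Explain that joke to me")
```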

mark-lord commented 2 months ago

Oop, didn't realise there was already a PR on this: https://github.com/ml-explore/mlx-examples/pull/989. Haven't had time to take a look yet, but from what I gather it might be tackling largely the same thing.

awni commented 1 month ago

Thanks for the PR! I changed how the caching works in MLX LM a bit to make this exact use case much easier.

You can see PR #1015 and the example there. It makes avoiding recomputation of KV caches and serializing them to disk much cleaner and easier.
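
Roughly, the reworked flow looks something like the sketch below. This assumes the prompt-cache helpers are exposed as make_prompt_cache / save_prompt_cache / load_prompt_cache in mlx_lm.models.cache and that generate accepts a prompt_cache argument; check #1015 for the actual names and signatures.

```python
from mlx_lm import load, generate

# Assumed helper names from the reworked caching in mlx_lm; see PR #1015
# for the real interface.
from mlx_lm.models.cache import (
    make_prompt_cache,
    save_prompt_cache,
    load_prompt_cache,
)

model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")

# One in-memory cache reused across turns; each call appends to it, so only
# the new prompt tokens get encoded.
prompt_cache = make_prompt_cache(model)

generate(model, tokenizer,
         prompt="From now on, talk like a pirate. Tell me a joke.",
         prompt_cache=prompt_cache, verbose=True)

generate(model, tokenizer,
         prompt="Explain that joke to me.",
         prompt_cache=prompt_cache, verbose=True)

# Persist the cache between runs instead of re-encoding the chat history.
save_prompt_cache("rollingcache.safetensors", prompt_cache)
prompt_cache = load_prompt_cache("rollingcache.safetensors")
```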

It should mostly subsume the case you fixed here, so I will close this.