huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

how to remove kv cache? #31717

Open tsw123678 opened 6 days ago

tsw123678 commented 6 days ago

Feature request

When I use the generate() function of a language model for inference, the kv-cache is also stored in the GPU memory. Is there any way to clear this kv-cache before continuing to call generate()?

Motivation

I have a lot of text to process, so I use a for loop to call generate(). To avoid OOM, I need to clear the kv-cache before the end of each loop iteration.

Your contribution

none

amyeroberts commented 5 days ago

Hi @tsw123678, you should be able to pass in use_cache=False when calling generate. Note that this will result in significant slowdowns when generating
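
For reference, a minimal sketch of that suggestion (the model name and prompt below are placeholders, not from the thread):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")                 # placeholder model
    model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")

    input_ids = tokenizer("Hello, world", return_tensors="pt").input_ids.to(model.device)
    # use_cache=False disables the KV cache entirely: memory stays flat, but every new
    # token recomputes attention over the whole sequence, so generation is much slower
    out = model.generate(input_ids, use_cache=False, max_new_tokens=20)
    print(tokenizer.decode(out[0], skip_special_tokens=True))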

tsw123678 commented 5 days ago

> Hi @tsw123678, you should be able to pass in use_cache=False when calling generate. Note that this will result in significant slowdowns when generating

Thank you for the tip. Since I need to handle a large number of conversations (over 10,000), I can't afford to give up the speed of the KV cache. I just want to know whether there is a way to clear the current KV cache after the generate() call.
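
One possible pattern, sketched here under the assumption of a recent transformers version where generate() accepts a Cache object via past_key_values (the model and prompts are placeholders): hold the cache reference yourself and drop it explicitly once the output is saved.

    import gc
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

    tokenizer = AutoTokenizer.from_pretrained("gpt2")                 # placeholder model
    model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")
    user_prompt = ["prompt 1", "prompt 2"]                            # placeholder prompts

    for prompt in user_prompt:
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
        cache = DynamicCache()  # the KV cache for this generate() call, held explicitly
        out = model.generate(input_ids, past_key_values=cache, max_new_tokens=64)
        res = tokenizer.decode(out[0], skip_special_tokens=True)
        # ... save res ...

        # drop every reference to the CUDA tensors (including the cache),
        # then release the cached allocator blocks
        del input_ids, out, cache, res
        gc.collect()
        torch.cuda.empty_cache()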

amyeroberts commented 5 days ago

Hi @tsw123678, to be able to help, you would need to clarify the construction of the loop i.e. what's being looped over.

Am I right in understanding that the GPU memory is not freed after a generate call?

cc @gante

tsw123678 commented 5 days ago

> Hi @tsw123678, to be able to help, you would need to clarify the construction of the loop i.e. what's being looped over.
>
> Am I right in understanding that the GPU memory is not freed after a generate call?
>
> cc @gante

Yeah, my loop code can be simplified as follows:

user_prompt = ["prompt 1", "prompt 2", "prompt 3", ..., "prompt n"]

for prompt in user_prompt:
    # tokenize, move to the GPU, generate, and decode
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    out = model.generate(input_ids)
    res = tokenizer.decode(out[0], skip_special_tokens=True)
    # some IO operation to save the result

    del input_ids, out, res
    torch.cuda.empty_cache()

I've noticed that deleting the variables that live on CUDA, such as input_ids, does not prevent the OOM issue. From my analysis, I believe the accumulation of the KV cache is the cause of the OOM problem. Thank you very much for your help.
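
A quick way to check this (a diagnostic sketch, not from the thread): print the allocated CUDA memory at the end of each iteration. If memory_allocated() keeps growing across iterations, some reference to a CUDA tensor is still alive; if only the peak is high, the memory is the transient KV cache used inside generate() and is released between calls.

    import torch

    # run this at the end of each loop iteration, after del and empty_cache():
    allocated = torch.cuda.memory_allocated() / 1e9       # tensors still referenced
    peak = torch.cuda.max_memory_allocated() / 1e9        # peak usage, includes generate()'s KV cache
    print(f"allocated: {allocated:.2f} GB, peak: {peak:.2f} GB")
    torch.cuda.reset_peak_memory_stats()                  # measure the next iteration's peak fresh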

swtb3 commented 2 days ago

Hi folks, has there been any development on this? I am having an issue, described here: https://discuss.huggingface.co/t/kv-cache-managment/95481

I believe it is down to the static KV cache. Note that I do not seem to experience any OOM, and GPU utilisation stays at 20 GB/40 GB. Inference just stops and leaves the process hanging. It would be nice to see more practical documentation around this feature, as it is not clear how to use or manage it beyond typical use.

Edit: in my case I use the text generation pipeline. The problem is present for Mistral 7B and Llama3-7B; the smaller Phi-3 (3.8B) does not bump into this issue.

Edit 2: Please disregard my comment; I have narrowed the issue down to Chromadb, so it isn't relevant here. Though I am still interested in seeing an option to clear the cache.

amyeroberts commented 1 day ago

Also cc @ArthurZucker re cache