huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Static KV cache status: How to use it? Does it work for all models? #33270

Open oobabooga opened 2 months ago

oobabooga commented 2 months ago

I see that there are many PRs about StaticCache, but I couldn't find clear documentation on how to use it.

What I want

Who can help?

Maybe @gante

zucchini-nlp commented 2 months ago

Some useful resources from the docs:

gante commented 2 months ago

@oobabooga the links @zucchini-nlp shared should enable what you want with respect to having a static memory. I'm not 100% sure StaticCache works with all quantization types, as they sometimes rewrite parts of the forward pass. Let us know if the information is not clear or if you find issues 🤗

Also on the topic of memory: we've merged two PRs recently that should lower memory requirements, regardless of the cache type. No action required from a user's point of view:

Oxi84 commented 2 months ago

I actually see quite a big difference in peak memory usage between 4.42.3 and 4.44.2, and 4.44.2 is also 20-30 percent faster.

oobabooga commented 1 month ago

The documentation is not clear. The first link recommends doing

model.generation_config.cache_implementation = "static"

The second one recommends passing a kwarg to model.generate:

# simply pass cache_implementation="static"
out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="static")

The third one defines StaticCache and provides a forward pass example, but it doesn't concern .generate.
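
For context, here is a minimal sketch of that third, forward-pass-only pattern. The model name and max_cache_len are placeholders chosen for illustration, and the StaticCache constructor argument names have varied slightly across transformers versions:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder model for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)

# pre-allocate a fixed-size cache, then run a plain forward pass that writes into it
past_key_values = StaticCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=1024,
    device=model.device,
    dtype=model.dtype,
)
with torch.no_grad():
    outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
next_token_logits = outputs.logits[:, -1, :]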

How do I explicitly specify the maximum sequence length for the static cache before using .generate and before the cache memory gets allocated?

gante commented 1 month ago

@oobabooga

cache_implementation="static" creates a cache whose maximum size is defined by the prompt length plus max_new_tokens. If you call generate repeatedly but increase max_new_tokens, it will instantiate a new (larger) cache. Think of the cache_implementation argument as a low-effort, low-flexibility flag.
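
A short sketch of that behaviour, assuming a causal LM is already loaded as model along with tokenized inputs:

# the first call allocates a static cache sized to the prompt length + 20
out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="static")

# a later call that needs a longer sequence no longer fits in that cache,
# so generate instantiates a new, larger static cache before decoding
out = model.generate(**inputs, do_sample=False, max_new_tokens=200, cache_implementation="static")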

How do I explicitly specify the maximum sequence length for the static cache before using .generate and before the cache memory gets allocated?

For complete flexibility, you should instantiate a Cache object outside generate and pass it through the past_key_values argument (as in this example). This is the best solution if you expect generate to be called with different input sizes and max_new_tokens values :)
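
A minimal sketch of that pattern, with a placeholder model and max_cache_len (the StaticCache constructor argument names may differ slightly between transformers versions):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder model for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)

# the cache memory is allocated here, with an explicit maximum sequence length,
# before .generate is ever called
past_key_values = StaticCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=4096,
    device=model.device,
    dtype=model.dtype,
)
out = model.generate(**inputs, do_sample=False, max_new_tokens=20, past_key_values=past_key_values)

# to reuse the same pre-allocated buffer for another prompt,
# clear it instead of building a new cache
past_key_values.reset()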


regarding difficulty in parsing information: what do you think would help? Perhaps separating the docs into basic and advanced usage?

oobabooga commented 1 month ago

So allocations only happen once per .generate call -- that's good enough; my concern was with repeated memory allocations for each token. Thanks for the clarification.

I have tested the speed with/without static cache for a 7B model, and could not find a speed improvement by using cache_implementation="static" by itself, so I assume the main benefit is when this is combined with torch.compile.

Before

Prompt processing (t/s): 7628.15
Text generation (t/s): 47.63

After

Prompt processing (t/s): 7256.68
Text generation (t/s): 42.26

regarding difficulty in parsing information: what do you think would help? Perhaps separating the docs into basic and advanced usage?

I think that the exact clarifications in your last comment would help if included in the documentation.

My main interest in the static cache is that when ExLlama (v1) was introduced, I could fit a lot more context on my 3090 than with AutoGPTQ + transformers. But maybe the excess memory usage was unrelated to the static cache, as I don't see a VRAM difference with or without cache_implementation="static". For reference, this is how the static cache is used in the ExLlamaV2 generator:

https://github.com/turboderp/exllamav2/blob/03b2d551b2a3a398807199456737859eb34c9f9c/examples/inference.py#L12

ArthurZucker commented 1 month ago

Yep, if you want to fit a longer context I would recommend OffloadedStaticCache! See here
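
A minimal sketch, assuming a model and tokenized inputs are already loaded and that the "offloaded_static" cache_implementation string is available in your transformers version:

# keeps the pre-allocated KV cache in CPU RAM and streams per-layer key/value
# tensors to the GPU as they are needed, trading some speed for longer context
# in the same amount of VRAM
out = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=20,
    cache_implementation="offloaded_static",
)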

github-actions[bot] commented 2 weeks ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

gante commented 1 week ago

@oobabooga yeah, StaticCache by itself won't be faster -- it only shines together with torch.compile
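
A minimal sketch of that combination (static cache plus torch.compile); the model choice and compile settings are placeholder assumptions picked for illustration:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder model for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

# a static cache gives the key/value tensors fixed shapes, which is what lets
# torch.compile build one graph instead of recompiling as the sequence grows
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("The theory of relativity states that", return_tensors="pt").to(model.device)
out = model.generate(**inputs, do_sample=False, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))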