oobabooga opened this issue 2 months ago
Some useful resources from the docs:
@oobabooga the links @zucchini-nlp shared should enable what you want with respect to having a static memory. I'm not 100% sure `StaticCache` works with all quantization types, as they sometimes rewrite parts of the forward pass. Let us know if the information is not clear or if you find issues 🤗
Also on the topic of memory: we've merged two PRs recently that should lower memory requirements, regardless of the cache type. No action required from a user's point of view:
I actually see quite a big difference in peak memory usage between 4.42.3 and 4.44.2, and 4.44.2 is also 20-30 percent faster.
The documentation is not clear. The first link recommends doing

```python
model.generation_config.cache_implementation = "static"
```

The second one recommends passing a kwarg to `model.generate`:

```python
# simply pass cache_implementation="static"
out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="static")
```

The third one defines `StaticCache` and provides a forward pass example, but it doesn't concern `.generate`.
How do I explicitly specify the maximum sequence length for the static cache before using `.generate` and before the cache memory gets allocated?
@oobabooga `cache_implementation` creates a cache whose maximum size is defined by the prompt length and `max_new_tokens`. If you call it repeatedly but e.g. increase `max_new_tokens`, it will instantiate a new (larger) cache. You can see the `cache_implementation` argument as a low-effort, low-flexibility flag.
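For illustration, a minimal sketch of that low-effort path (the model name is just a placeholder; per the above, the second call, which asks for more new tokens, should trigger a new, larger cache allocation):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder model
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

inputs = tok("The quick brown fox", return_tensors="pt").to(model.device)

# Static cache sized from the prompt length + 20 new tokens
out = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")

# Larger max_new_tokens -> generate instantiates a new, larger static cache
out = model.generate(**inputs, max_new_tokens=200, cache_implementation="static")
```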
> How do I explicitly specify the maximum sequence length for the static cache before using `.generate` and before the cache memory gets allocated?
For complete flexibility, you should instantiate a `Cache` object outside `generate` and pass it through the `past_key_values` argument (as in this example). This is the best solution if you expect different `generate` calls with different input sizes and `max_new_tokens` values :)
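A minimal sketch of that pattern, assuming a recent transformers release (the `StaticCache` constructor arguments have shifted slightly across versions, so check them against your installed version; the model name and `max_cache_len` below are illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

# Allocate the whole KV cache up front, with an explicit maximum sequence length
past_key_values = StaticCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=4096,  # hard upper bound on prompt + generated tokens
    device=model.device,
    dtype=model.dtype,
)

inputs = tok("Hello, my name is", return_tensors="pt").to(model.device)
out = model.generate(**inputs, do_sample=False, max_new_tokens=20, past_key_values=past_key_values)
print(tok.decode(out[0], skip_special_tokens=True))
```

If your version exposes `past_key_values.reset()`, the same object can be cleared and reused across `generate` calls, so the allocation happens only once as long as the total length stays under `max_cache_len`.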
regarding difficulty in parsing information: what do you think would help? Perhaps separating the docs into basic and advanced usage?
So allocations only happen once per `.generate` call -- that's good enough; my concern was with repeated memory allocations for each token. Thanks for the clarification.
I have tested the speed with/without static cache for a 7B model, and could not find a speed improvement by using `cache_implementation="static"` by itself, so I assume the main benefit is when this is combined with `torch.compile`.
**Before (without static cache)**
- Prompt processing (t/s): 7628.15
- Text generation (t/s): 47.63

**After (with `cache_implementation="static"`)**
- Prompt processing (t/s): 7256.68
- Text generation (t/s): 42.26
> regarding difficulty in parsing information: what do you think would help? Perhaps separating the docs into basic and advanced usage?
I think that the exact clarifications in your last comment would help if included in the documentation.
My main interest in static cache is the fact that when ExLlama (v1) got introduced, I could fit a lot more context in my 3090 than with AutoGPTQ + transformers. But maybe the excess memory usage was unrelated to static cache, as I don't see a VRAM difference with or without `cache_implementation="static"`. For reference, this is how static cache is used in the ExLlamaV2 generator:
Yep, if you want to fit longer context I would recommend the `OffloadedStaticCache`! See here
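In case it helps, a sketch of the low-effort route to that, assuming a transformers version where `"offloaded_static"` is an accepted `cache_implementation` string (otherwise, an `OffloadedStaticCache` can be built directly and passed via `past_key_values`, like the plain `StaticCache` above):

```python
# Statically allocated KV cache that is offloaded to CPU RAM, trading some speed
# for the ability to fit longer contexts on the GPU. `inputs` / `model` as before.
out = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=256,
    cache_implementation="offloaded_static",
)
```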
@oobabooga yeah, `StaticCache` by itself won't be faster -- it only shines together with `torch.compile`
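A minimal sketch of that combination, following the pattern shown in the docs (the model name and generation settings are illustrative, and the first couple of calls will be slow while `torch.compile` warms up):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

# The static cache gives the decoding step fixed tensor shapes, which is what lets
# torch.compile specialize it instead of recompiling as the sequence grows.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tok("The theory of relativity states that", return_tensors="pt").to(model.device)
out = model.generate(**inputs, do_sample=False, max_new_tokens=64)
print(tok.decode(out[0], skip_special_tokens=True))
```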
I see that there are many PRs about `StaticCache`, but I couldn't find clear documentation on how to use it.
What I want

- To not have Transformers allocate memory dynamically for the KV cache when using `model.generate()`, as that leads to increased memory usage (due to garbage collection not happening fast/often enough) and worse performance.
- To use that by default always, for every model, for every supported quantization backend (AutoAWQ, AutoGPTQ, AQLM, bitsandbytes, etc.).
Who can help?
Maybe @gante