huggingface / transformers

šŸ¤— Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

KV cache with CPU offloading #30704

Closed n17s closed 1 month ago

n17s commented 3 months ago

Feature request

I would like to contribute a KV cache implementation that keeps only a couple of layers on the GPU: the layer currently being processed in the forward pass, plus a prefetched copy of the next layer. The KV cache for the remaining layers lives on the CPU, which usually has much larger capacity. Once the forward pass is done with a layer, its KV cache is evicted back to the CPU; it is fetched again during the generation of the next token. This is all implemented in this gist, which can be used as a drop-in replacement for the transformers.cache_utils.DynamicCache class.

Motivation

Performing inference with large language models on very long contexts can easily exhaust GPU memory, and the vast majority of that memory is consumed by the KV cache. The KV cache is indexed by layer and accessed sequentially by layer index during the forward pass. This gives us both a natural segmentation of the KV cache by layer index and a clear pattern for prefetching and eviction: when we do the forward pass on layer k, we evict layer k-1 and prefetch layer k+1. Thus, there are always at most two layers of KV cache in GPU memory.
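To make the access pattern concrete, here is a minimal sketch of the evict/prefetch loop (hypothetical class and method names, not the gist's actual code; streams and pinned memory are omitted for brevity):

```python
import torch
from typing import List, Optional, Tuple

class LayerOffloadSketch:
    """Keep at most two layers of KV cache on the GPU: the layer currently
    being computed and a prefetched copy of the next one; the rest stay on CPU."""

    def __init__(self, num_layers: int, device: str = "cuda"):
        self.device = device
        self.key_cache: List[Optional[torch.Tensor]] = [None] * num_layers
        self.value_cache: List[Optional[torch.Tensor]] = [None] * num_layers

    def before_forward(self, k: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Evict layer k-1: its forward pass is finished, so move it back to CPU.
        if k > 0 and self.key_cache[k - 1] is not None:
            self.key_cache[k - 1] = self.key_cache[k - 1].to("cpu")
            self.value_cache[k - 1] = self.value_cache[k - 1].to("cpu")
        # Prefetch layer k+1 so the copy can overlap with layer k's compute.
        if k + 1 < len(self.key_cache) and self.key_cache[k + 1] is not None:
            self.key_cache[k + 1] = self.key_cache[k + 1].to(self.device, non_blocking=True)
            self.value_cache[k + 1] = self.value_cache[k + 1].to(self.device, non_blocking=True)
        # Layer k itself was prefetched while layer k-1 was being computed.
        return self.key_cache[k], self.value_cache[k]
```

In a real implementation the copies would be issued on a separate CUDA stream, ideally from pinned host memory, so that the prefetch genuinely overlaps with compute.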

The overlap of I/O and compute is not perfect. For example, for a 7B-parameter model on a single H100 I am getting 12 tokens/sec with my KV cache vs. 16 tokens/sec with the standard KV cache for a context size of 8k tokens. However, the standard KV cache runs out of memory at a context size of 128k tokens, while mine still works.

Your contribution

I have provided an initial implementation in this gist. I'm not sure what the right way to integrate it is: it could be a new class, an option on DynamicCache, or an automatic fallback when we are about to run out of memory. I'm also not very familiar with the PR process, since this is my first issue, so it might be better if someone from the HF team could shepherd this through.

amyeroberts commented 3 months ago

cc @gante @ArthurZucker

ArthurZucker commented 3 months ago

Do you know if this fares well with the StaticCache? Otherwise very interesting! šŸ¤—

n17s commented 3 months ago

Sorry, I don't know if StaticCache can be supported in a similar way.

I should also mention a couple of things. The operations from_legacy_cache, to_legacy_cache, and reorder_cache are not really tested in my workflow. I have seen modeling code that converts back and forth to the legacy format on each generation step. My workflow does not, but the conversion to/from legacy is a bit problematic because there is nowhere to store the original device in the legacy format. The current version ends up moving the key states back to their original device during this conversion, which is not ideal.

If we can assume that the original device is always 'cuda' (instead of 'cuda:0' for some layers, 'cuda:1' for others, etc.), then we can eliminate the need to store the original device. However, I am not familiar enough with the codebase to tell whether this is safe to assume.
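For illustration only, recording the original device per layer could look roughly like this (a hypothetical sketch, not necessarily how the gist does it):

```python
import torch
from typing import List

class DeviceTrackingSketch:
    """Remember which accelerator device each layer's KV tensors were created on,
    so offloaded tensors can be restored without assuming a single 'cuda' device."""

    def __init__(self):
        self.original_device: List[torch.device] = []

    def record(self, layer_idx: int, key_states: torch.Tensor) -> None:
        # Called the first time a layer's KV tensors are added to the cache.
        if layer_idx == len(self.original_device):
            self.original_device.append(key_states.device)

    def restore(self, layer_idx: int, offloaded: torch.Tensor) -> torch.Tensor:
        # Move an offloaded CPU tensor back to the device it originally lived on.
        return offloaded.to(self.original_device[layer_idx], non_blocking=True)
```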

ArthurZucker commented 3 months ago

You should not need to test them! They are for backward compatibility, and generate will soon not need them at all.

IMO storing the device makes sense. It's also something that could be interesting for @muellerzr and @SunMarc: an accelerate-friendly way to handle that? šŸ¤—

n17s commented 3 months ago

I'm going to test reorder_cache today-ish and fix any issues.

I can make a PR afterwards, but I need some guidance on how this should be integrated. A new class? An option on the constructor of the existing DynamicCache? Something else?

ArthurZucker commented 3 months ago

Yep a new class sounds great!

n17s commented 3 months ago

As an update, I am a bit blocked because beam search has become extremely slow with this approach. Profiling suggests that torch.index_select operations on the CPU are very slow.

ArthurZucker commented 2 months ago

Would transferring the tensor to the device when this happens (in the cache class) not be efficient enough?

SunMarc commented 2 months ago

IMO storing the device makes sense. It's also something that could be interesting for @muellerzr and @SunMarc: an accelerate-friendly way to handle that? šŸ¤—

+1 on @ArthurZucker's comments. I'm not sure how we would approach this in accelerate, so I think it is best to handle it within this new class.

n17s commented 2 months ago

Would transferring the tensor to the device when this happens (in the cache class) not be efficient enough?

Nah, too much back and forth. What seems to work is to delay the operation until the tensors are back on their device. This is now implemented in the updated gist, around lines 37 to 52.
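For readers following along, here is a minimal sketch of the general idea (not the gist's actual code; class and method names are hypothetical): the beam permutation is accumulated instead of applied, and only applied once a layer's tensors are on the GPU again.

```python
import torch
from typing import List, Optional

class DelayedReorderSketch:
    """Instead of running index_select on CPU-resident KV tensors, remember the
    beam indices and apply the composed permutation after the layer is fetched
    back to the GPU, where index_select is fast."""

    def __init__(self, num_layers: int):
        self.key_cache: List[Optional[torch.Tensor]] = [None] * num_layers
        self.value_cache: List[Optional[torch.Tensor]] = [None] * num_layers
        self.pending_beam_idx: List[Optional[torch.Tensor]] = [None] * num_layers

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        # Don't touch the (possibly CPU-resident) tensors; just compose the permutation.
        for i in range(len(self.key_cache)):
            if self.pending_beam_idx[i] is None:
                self.pending_beam_idx[i] = beam_idx.clone()
            else:
                self.pending_beam_idx[i] = self.pending_beam_idx[i][beam_idx]

    def fetch_layer(self, i: int, device: str = "cuda"):
        # Bring the layer back to the GPU, then apply the delayed reordering there.
        key = self.key_cache[i].to(device, non_blocking=True)
        value = self.value_cache[i].to(device, non_blocking=True)
        if self.pending_beam_idx[i] is not None:
            idx = self.pending_beam_idx[i].to(device)
            key, value = key.index_select(0, idx), value.index_select(0, idx)
            self.pending_beam_idx[i] = None
        self.key_cache[i], self.value_cache[i] = key, value
        return key, value
```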

When I test with

model.generate(inputs["input_ids"], num_beams=4, num_beam_groups=2, num_return_sequences=4, diversity_penalty=1.0, max_new_tokens=50, early_stopping=True)

it produces the same results as the original implementation, but I am curious to learn whether I am breaking any assumptions by delaying the operation.

ArthurZucker commented 2 months ago

This looks good TBH šŸ”„ I think at this point it makes sense to open a PR with an OffloadedCache (or another name, but TL;DR a new cache class).

n17s commented 2 months ago

Will do. Besides this class, are there any other changes I need to include in the PR?

ArthurZucker commented 2 months ago

I think that should be it!

n17s commented 2 months ago

I have started the PR, linked above. I also added the possibility to specify cache_implementation="offloaded" in GenerationConfig.
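For context, once the option is available, usage would look roughly like this (a sketch; the checkpoint name is just an example):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # any causal LM checkpoint works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="cuda")

inputs = tokenizer("A very long document ...", return_tensors="pt").to(model.device)

# Offload the KV cache to the CPU, keeping only the active layers on the GPU.
out = model.generate(**inputs, max_new_tokens=50, cache_implementation="offloaded")
print(tokenizer.decode(out[0], skip_special_tokens=True))
```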

gerbenvv commented 1 month ago

Some comments about the gist / PR:

If you want this to be very performant, I think you should try to pin the CPU memory and mark it as static. With pinned memory the copies will be truly non-blocking, whereas with unpinned CPU memory they will not be; and marking the CPU memory as static will prevent recompiles.

I am also wondering why a CUDA stream is created but never awaited when the memory is actually used. This probably works because the copy is in fact blocking, but it doesn't make much sense to me. Awaiting the stream will be needed if you make the copy truly non-blocking (use wait_stream).
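To illustrate the stream point, a minimal sketch of pinned host memory plus an explicit wait_stream synchronization (shapes and names are arbitrary):

```python
import torch

copy_stream = torch.cuda.Stream()

# Pinned (page-locked) host memory is what makes the H2D copy truly asynchronous.
kv_cpu = torch.empty(1, 8, 1024, 128, pin_memory=True)
kv_gpu = torch.empty_like(kv_cpu, device="cuda")

# Issue the prefetch on a side stream so it can overlap with compute.
with torch.cuda.stream(copy_stream):
    kv_gpu.copy_(kv_cpu, non_blocking=True)

# Before the main stream reads kv_gpu, it must wait for the copy stream.
torch.cuda.current_stream().wait_stream(copy_stream)
attn_input = kv_gpu.sum()  # stand-in for the real attention computation
```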

I will try to whip up a gist of my own.

ArthurZucker commented 1 month ago

Sounds good šŸ‘€

gerbenvv commented 1 month ago

Cool, nice work btw. I will start with a static cache implementation: it's a bit easier to do and will be a good stepping stone toward a dynamic cache version with non-blocking copies. Also, it's nice to have both options.

It might actually not be possible to be dynamic and still have pinned memory on the CPU / static memory on the GPU, but we'll see. If that's the case, it's still nice to have a static cache version of this.

gerbenvv commented 1 month ago

Okay, I whipped something up 😁

Check https://gist.github.com/gerbenvv/282ed3c981a63ad71a301cdc1a705ef1

It's not finished yet, but it's a proof of concept.

ArthurZucker commented 1 month ago

Opening a PR will make it easier for us to review, if you want to open one!

gerbenvv commented 1 month ago

Yeah, I will open one after I clean up the code a bit, test it with beam search, and add some docstrings.

It will probably have some conflicts with the existing PR from @n17s, however, since it touches some of the same code.

I may try to make beam search faster by using the delayed indexing trick, or that could be future work.

Likewise, making some performance improvements on top of @n17s's work (making the copies non-blocking, not copying the whole cache to the CPU) could be future work as well.

Curious to hear @n17s thoughts.

gerbenvv commented 1 month ago

Opened the PR, but note that it's a WIP; it doesn't work properly yet (still debugging). I will move it out of WIP when it's ready to be reviewed, which will probably be sometime this week.

n17s commented 1 month ago

Curious to hear @n17s thoughts.

@gerbenvv

n17s commented 1 month ago

Also, I suggest starting a new issue and linking/mentioning this one; when my PR is merged, this issue will be closed. Then edit your PR so it fixes the issue you open.

gerbenvv commented 1 month ago

Okay, I will create a new issue for this. Also: I have not subclassed from StaticCache since it doesn't share any of the code with it.

n17s commented 1 month ago

I have not subclassed from StaticCache since it doesn't share any of the code with it

Understood. Subclassing is mostly for making code such as if isinstance(foo, StaticCache): work in both cases.