turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs
MIT License

How to implement paged attention in HF format? #616

Open fahadh4ilyas opened 2 months ago

fahadh4ilyas commented 2 months ago

So, I just created an HF-format wrapper for exllamav2 and it works well with batching. My code is in #606. Now I have a new problem: a bigger batch means more memory usage, and most of that goes to padding, especially when the token sequences differ in length. Could you explain how exllamav2's paged attention works in code? I checked exllamav2/model.py, and PagedParams is used there, but I don't know what to pass as its parameters.

turboderp commented 2 months ago

To use the paged mode (flash-attn only), you first need a cache initialized with a batch size of 1 and a length which is some multiple of the page size. The page size is always 256 with the current version of flash-attn. Essentially this cache won't have a shape, just a total capacity.
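The sizing rule above can be sketched as a small pure-Python helper that rounds each sequence's token budget up to whole 256-token pages and sums them. paged_cache_length is a hypothetical name, and its result is just a candidate total length for the single batch-size-1 cache; nothing here touches the exllamav2 API.

```python
PAGE_SIZE = 256  # page size of the current flash-attn paged kernel, per this thread

def paged_cache_length(seq_lengths, max_new_tokens, page_size=PAGE_SIZE):
    """Total cache capacity (in tokens) needed so every sequence can grow by
    max_new_tokens, with each sequence rounded up to whole pages."""
    total_pages = 0
    for n in seq_lengths:
        needed = n + max_new_tokens
        total_pages += -(-needed // page_size)  # ceiling division
    return total_pages * page_size

# The three sequences from this thread: 2 + 6 + 4 pages = 12 pages
print(paged_cache_length([10, 1025, 320], 500))  # 3072
```

Rounding happens per sequence because each page belongs to one sequence's block list: for two 257-token sequences, per-sequence rounding gives 4 pages, while rounding the 514-token total would give only 3.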

PagedParams is constructed like so:

from exllamav2.attn import ExLlamaV2Attention

params = ExLlamaV2Attention.PagedParams(
    batch_size = batch_size,
    block_index = block_index,
    cache_seqlens = cache_seqlens,
    max_cache_seqlen = cache_seqlens.max().item(),
    page_size = 256,
    q_len = q_len,
)

So say you have three sequences that are currently 10, 1025 and 320 tokens long, respectively, you want room in the cache for each to grow by 500 tokens, and you're forwarding a single token per sequence. That could look like:

batch_size: 
    3

block_index:
    [
        [  0, 1,  0,  0,  0, 0 ],   # positions 0:512 in the cache, plus padding
        [  2, 3,  4,  5,  6, 7 ],   # positions 512:2048
        [  8, 9, 10, 11, 12, 0 ]    # positions 2048:3328, plus padding
    ]

cache_seqlens:
    [ 10, 1025, 320 ]

page_size:
    256

q_len:
    1

input_ids:
    [ 
        [token_a],
        [token_b],
        [token_c]
    ]

So when the forward pass writes the keys/values for position 10 of the first sequence, it only touches page 0 in the cache. At the same time it writes position 1025 of the second sequence, i.e. flat cache position 512+1025, which falls in page 6, and so on. It's the cache_seqlens tensor that determines how long each past is and thereby which page to look up in the block index.
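The lookup described above can be sketched in plain Python. This only models the index arithmetic (the real work happens inside the flash-attn kernel), treating the cache as one flat array of 256-token pages; write_location is a hypothetical helper name.

```python
PAGE_SIZE = 256

def write_location(block_index, cache_seqlens, page_size=PAGE_SIZE):
    """For each sequence, return (physical_page, flat_cache_position) of the
    next token to be written, given how many tokens are already cached."""
    out = []
    for row, seqlen in zip(block_index, cache_seqlens):
        logical_page = seqlen // page_size  # which entry of this row's page list
        offset = seqlen % page_size         # position inside that page
        physical_page = row[logical_page]   # page number in the shared cache
        out.append((physical_page, physical_page * page_size + offset))
    return out

block_index = [
    [0, 1, 0, 0, 0, 0],
    [2, 3, 4, 5, 6, 7],
    [8, 9, 10, 11, 12, 0],
]
cache_seqlens = [10, 1025, 320]
print(write_location(block_index, cache_seqlens))
# [(0, 10), (6, 1537), (9, 2368)]
```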

Now, there are some choices you could make about how to get to the above point in the first place. input_ids is still always a rectangular tensor, so to prefill the initial 10, 1025 and 320 tokens without padding, you'd need to do three forward passes.

You could do one with a shape of (3, 10), then another with shape (2, 310) and finally (1, 705).
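The three pass shapes can be derived mechanically: each pass covers the tokens that all still-unfinished sequences share, then the shortest sequences drop out. A minimal sketch, with staged_prefill_shapes as a hypothetical name and only the shapes computed (no tensors or model calls):

```python
def staged_prefill_shapes(lengths):
    """Split a ragged batch of prompts into padding-free rectangular
    (batch_size, q_len) prefill passes."""
    shapes = []
    done = 0  # tokens already prefilled for every still-active sequence
    for target in sorted({n for n in lengths if n > 0}):
        # Sequences at least `target` long still need positions done..target-1.
        active = sum(1 for n in lengths if n >= target)
        shapes.append((active, target - done))
        done = target
    return shapes

print(staged_prefill_shapes([10, 1025, 320]))  # [(3, 10), (2, 310), (1, 705)]
```

Within each pass every active row starts at the same cached length (done), which is what lets the input stay rectangular without padding.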

Or you can just do each sequence in the batch as a bsz 1 forward pass. This is what the dynamic generator does, and it simplifies things a lot, especially for continuous batching. I.e.:

prompt a:
    batch_size: 1
    block_index: [[0]]
    cache_seqlens: [0]
    q_len: 10
    input_ids: tokenizer.encode(prompt_a)

prompt b:
    batch_size: 1
    block_index: [[2, 3, 4, 5, 6]]
    cache_seqlens: [0]
    q_len: 1025
    input_ids: tokenizer.encode(prompt_b)

prompt c:
    batch_size: 1
    block_index: [[8, 9]]
    cache_seqlens: [0]
    q_len: 320
    input_ids: tokenizer.encode(prompt_c)
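The per-prompt values listed above follow directly from each prompt's length and the pages reserved for it. A hedged sketch: prefill_params is a hypothetical helper, plain lists stand in for the tensors a real call would use, and cache_seqlens is modelled as a 1-D list with one entry per sequence.

```python
PAGE_SIZE = 256

def prefill_params(prompt_len, reserved_pages, page_size=PAGE_SIZE):
    """Values for a bsz-1 prefill of one prompt into an empty sequence slot."""
    n_pages = -(-prompt_len // page_size)  # pages the prompt itself fills
    if n_pages > len(reserved_pages):
        raise ValueError("not enough pages reserved for this prompt")
    return {
        "batch_size": 1,
        "block_index": [reserved_pages[:n_pages]],  # a single row
        "cache_seqlens": [0],                       # nothing cached yet
        "q_len": prompt_len,
    }

reserved = {"a": [0, 1], "b": [2, 3, 4, 5, 6, 7], "c": [8, 9, 10, 11, 12]}
lengths = {"a": 10, "b": 1025, "c": 320}
for name in "abc":
    print(name, prefill_params(lengths[name], reserved[name])["block_index"])
# a [[0]]
# b [[2, 3, 4, 5, 6]]
# c [[8, 9]]
```

Trimming the row to the prompt's pages is optional; a row with extra reserved pages on the right works the same, since the block index is read from the left.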

There are a bunch of fun details about paged attention, such as the fact that the page indices don't need to be contiguous. They also don't need to be unique, as long as you're not updating the same page twice in a forward pass. The dynamic generator uses both of those details, for continuous batching and deduplication respectively.
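The one hard constraint here (no page written twice in the same forward pass) can be made concrete with a hypothetical pure-Python check. This is not how the dynamic generator is implemented; it only computes which physical pages each row writes and looks for collisions. The example rows share page 0 (non-unique) and the second row jumps from page 0 to page 2 (non-contiguous).

```python
PAGE_SIZE = 256

def pages_written(row, seqlen, q_len, page_size=PAGE_SIZE):
    """Physical pages touched when writing positions seqlen .. seqlen+q_len-1."""
    if q_len <= 0:
        return []
    first = seqlen // page_size
    last = (seqlen + q_len - 1) // page_size
    return [row[i] for i in range(first, last + 1)]

def no_double_writes(block_index, cache_seqlens, q_len, page_size=PAGE_SIZE):
    """True if no physical page is written more than once in this pass.
    Rows may still share pages they only read, e.g. a common full prefix page."""
    seen = set()
    for row, seqlen in zip(block_index, cache_seqlens):
        for page in pages_written(row, seqlen, q_len, page_size):
            if page in seen:
                return False
            seen.add(page)
    return True

# Two sequences sharing a full 256-token prefix in page 0, each writing its own page.
print(no_double_writes([[0, 1], [0, 2]], [256, 300], 1))  # True
# Both still writing into the shared page 0: not allowed in one forward pass.
print(no_double_writes([[0, 1], [0, 2]], [10, 10], 1))    # False
```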

If you don't want a predefined max_new_tokens length, you could allocate pages dynamically during inference. Nothing prevents you from adding page 13 after page 1 in the first sequence, or growing the block_index tensor by one column to add page 14 after page 7.
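The bookkeeping for dynamic allocation could look roughly like this free-list sketch. PagePool, ensure_capacity, release and to_block_index are hypothetical names, not exllamav2 APIs. Padding with 0 follows the block-index examples in this thread; those entries are never read because cache_seqlens limits how far each row is indexed.

```python
from collections import deque

PAGE_SIZE = 256

class PagePool:
    """Hypothetical free-list bookkeeping for pages in one shared paged cache."""

    def __init__(self, total_pages):
        self.free = deque(range(total_pages))

    def alloc(self):
        if not self.free:
            raise RuntimeError("no free pages left in the cache")
        return self.free.popleft()

    def ensure_capacity(self, pages, seqlen, q_len, page_size=PAGE_SIZE):
        """Append pages to one sequence's page list until it holds seqlen + q_len tokens."""
        needed = -(-(seqlen + q_len) // page_size)  # ceiling division
        while len(pages) < needed:
            pages.append(self.alloc())
        return pages

    def release(self, pages):
        """Return a finished sequence's pages to the pool for reuse."""
        self.free.extend(pages)
        pages.clear()

def to_block_index(page_lists):
    """Right-pad ragged page lists with 0 to form a rectangular block index."""
    width = max(len(p) for p in page_lists)
    return [p + [0] * (width - len(p)) for p in page_lists]

pool = PagePool(total_pages=16)
seqs = [[], [], []]
for pages, n in zip(seqs, [10, 1025, 320]):
    pool.ensure_capacity(pages, 0, n)  # room for the prompt only
print(to_block_index(seqs))  # [[0, 0, 0, 0, 0], [1, 2, 3, 4, 5], [6, 7, 0, 0, 0]]

pool.ensure_capacity(seqs[0], 256, 1)  # sequence a is about to write position 256
print(seqs[0])  # [0, 8]
```

Once finished sequences are released, a pool like this hands out non-contiguous pages, which is fine because each row of the block index only has to list its own pages in order.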

It does of course require some bookkeeping in your generator, and I'm not sure how well that plays together with HF and pipelines and whatnot.

fahadh4ilyas commented 2 months ago

Okay, I kind of get the concept. I think I want to forward each sequence as a bsz 1 forward pass. Does this mean I have to loop over each sequence before one big batched forward pass? What about the cache instance? Should I make one for each sequence, or just one for all of them? And how does the cache know which sequence is being forwarded with it?

turboderp commented 2 months ago

You use one cache for everything, and it's the block_index tensor that says which pages in the cache are used for each sequence, whether you're doing them one at a time or batching.

One way to go about it would be to start by tokenizing all the prompts in a batch, then constructing the block index based on how many pages each sequence is going to need, including both the prompt and the completion:

block_index_batch:
    [
        [  0, 1,  0,  0,  0, 0 ],  # 10+500 tokens -> 2 pages
        [  2, 3,  4,  5,  6, 7 ],  # 1025+500 tokens -> 6 pages
        [  8, 9, 10, 11,  0, 0 ]   # 320+500 tokens -> 4 pages
    ]
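Building that block index from the tokenized prompt lengths is a few lines of arithmetic. A minimal sketch, assuming pages are numbered consecutively across the batch and rows are right-padded with 0 as in the example; build_block_index is a hypothetical name and plain lists stand in for the tensor.

```python
PAGE_SIZE = 256

def build_block_index(prompt_lengths, max_new_tokens, page_size=PAGE_SIZE):
    """Reserve pages for prompt + completion per sequence, numbering pages
    consecutively across the batch, then right-pad rows with 0."""
    rows, next_page = [], 0
    for n in prompt_lengths:
        n_pages = -(-(n + max_new_tokens) // page_size)  # ceiling division
        rows.append(list(range(next_page, next_page + n_pages)))
        next_page += n_pages
    width = max(len(r) for r in rows)
    return [r + [0] * (width - len(r)) for r in rows]

print(build_block_index([10, 1025, 320], 500))
# [[0, 1, 0, 0, 0, 0], [2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 0, 0]]
```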

Then you run the three individual forward passes to prefill:

seq a: block_index = block_index_batch[0:1, :]
seq b: block_index = block_index_batch[1:2, :]
seq c: block_index = block_index_batch[2:3, :]

It doesn't matter if the block index has extra padding on the right, since it's indexed from the left. Then, for each generated token, you pass the full block_index_batch so a single forward pass can index into all three sequences at once.
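The prefill-then-decode flow can be sketched as a skeleton that only records the arguments each forward pass would receive; no model is called, and prefill_calls, decode_calls and pages_read are hypothetical names. It also shows why right padding is harmless: only the leftmost entries covering each row's cached length are ever read.

```python
PAGE_SIZE = 256

def pages_read(row, seqlen, page_size=PAGE_SIZE):
    """Leftmost block-index entries that cover the first seqlen cached tokens.
    Entries to the right of these (padding or reserved pages) are not looked at."""
    return row[: -(-seqlen // page_size)]

block_index_batch = [
    [0, 1, 0, 0, 0, 0],
    [2, 3, 4, 5, 6, 7],
    [8, 9, 10, 11, 0, 0],
]
prompt_lengths = [10, 1025, 320]

# Prefill: one bsz-1 pass per sequence using a one-row slice (padding included).
prefill_calls = []  # (block_index, cache_seqlens, q_len) for each pass
for i, n in enumerate(prompt_lengths):
    prefill_calls.append((block_index_batch[i : i + 1], [0], n))
cache_seqlens = list(prompt_lengths)

# Decode: one pass over the whole batch with the full block index and q_len = 1,
# after which every sequence is one token longer.
decode_calls = []
for _ in range(3):
    decode_calls.append((block_index_batch, list(cache_seqlens), 1))
    cache_seqlens = [s + 1 for s in cache_seqlens]

print(cache_seqlens)  # [13, 1028, 323]
print([pages_read(r, s) for r, s in zip(block_index_batch, cache_seqlens)])
# [[0], [2, 3, 4, 5, 6], [8, 9]]
```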

fahadh4ilyas commented 2 months ago

I understand, but I have another question: what about the input mask and position offsets? The input mask might already be taken care of, since masking is done inside flash attention. But what about the position offsets?

turboderp commented 1 month ago

You wouldn't use masking or position offsets in paged mode, only a list of sequence lengths, and then the flash-attn kernel handles the rest. This allows all sequences to start at position zero (as long as that corresponds to a page boundary in the cache, as determined by block_index) and have variable lengths as determined by cache_seqlens.
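One way to picture this: each sequence's new tokens sit at positions implied by its own cache_seqlens entry, counted from that sequence's start, so there is no per-row offset or padding mask to track. This hypothetical helper only illustrates those implied positions; it does not show how exllamav2 applies them internally, and in paged mode you don't construct them yourself.

```python
def positions_for_step(cache_seqlens, q_len):
    """Positions of the q_len new tokens in each sequence: seqlen .. seqlen+q_len-1,
    always relative to that sequence's own position 0."""
    return [list(range(s, s + q_len)) for s in cache_seqlens]

print(positions_for_step([10, 1025, 320], 1))  # [[10], [1025], [320]]
print(positions_for_step([0], 3))              # [[0, 1, 2]]
```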