fahadh4ilyas opened this issue 2 months ago
To use the paged mode (flash-attn only), you first need a cache initialized with a batch size of 1 and a length which is some multiple of the page size. The page size is always 256 with the current version of flash-attn. Essentially this cache won't have a shape, just a total capacity.
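For reference, allocating that flat cache might look roughly like this (a minimal sketch; the model directory and the number of pages are placeholders, and the exact loading calls may vary between exllamav2 versions):

```python
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer

config = ExLlamaV2Config("/path/to/model")   # placeholder model directory
config.prepare()
model = ExLlamaV2(config)

PAGE_SIZE = 256                              # fixed by the current flash-attn version
num_pages = 16                               # whatever total capacity you want, in pages

cache = ExLlamaV2Cache(
    model,
    batch_size = 1,                          # always 1: the cache is flat, pages give it structure
    max_seq_len = num_pages * PAGE_SIZE,     # total capacity, a multiple of the page size
    lazy = True,
)
model.load_autosplit(cache)
tokenizer = ExLlamaV2Tokenizer(config)
```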
PagedParams is constructed like so:
```python
params = ExLlamaV2Attention.PagedParams(
    batch_size = batch_size,
    block_index = block_index,
    cache_seqlens = cache_seqlens,
    max_cache_seqlen = cache_seqlens.max().item(),
    page_size = 256,
    q_len = q_len,
)
```
batch_size here is the actual size of your batch, even though you're using a flat cache. block_index is an int tensor of shape (batch_size, max_num_pages) which defines which pages in the cache to use for which sequences in the batch; it can be padded to the right with arbitrary values for pages you'll never get to. cache_seqlens is an int tensor of shape (batch_size,) determining where in each sequence the input IDs to the forward pass belong. q_len is the length of whatever you're sending through the forward pass, typically one. input_ids to the model would therefore be of shape (batch_size, q_len).
So say you have three sequences that are currently 10, 1025 and 320 tokens long, respectively, and you want room in the cache for each to grow by 500 tokens. You're forwarding a single token. That could look like:
batch_size:
3
block_index:
[
[ 0, 1, 0, 0, 0, 0 ], # positions 0:512 in the cache, plus some padding
[ 2, 3, 4, 5, 6, 7 ], # positions 512:2048
[ 8, 9, 10, 11, 12, 0 ] # positions 2048:3328, plus padding
]
cache_seqlens:
[ 10, 1025, 320 ]
page_size:
256
q_len:
1
input_ids:
[
[token_a],
[token_b],
[token_c]
]
So when the forward pass writes the keys/values for position 10, it only touches page 0 in the cache. At the same time it will write position 512+1025, which goes to page 6, etc. It's the cache_seqlens tensor that determines how long each past is and thereby which page to look up in the block index.
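Put into actual tensors, that example could look something like the sketch below. ExLlamaV2Attention lives in exllamav2.attn; how you then feed the params into the forward pass depends on your exllamav2 version, so the commented forward_chunk call (mirroring what the dynamic generator does) is just an assumption to adapt:

```python
import torch
from exllamav2.attn import ExLlamaV2Attention

device = "cuda:0"

block_index = torch.tensor([
    [ 0, 1,  0,  0,  0, 0 ],
    [ 2, 3,  4,  5,  6, 7 ],
    [ 8, 9, 10, 11, 12, 0 ],
], dtype = torch.int32, device = device)

cache_seqlens = torch.tensor([10, 1025, 320], dtype = torch.int32, device = device)

params = ExLlamaV2Attention.PagedParams(
    batch_size = 3,
    block_index = block_index,
    cache_seqlens = cache_seqlens,
    max_cache_seqlen = cache_seqlens.max().item(),
    page_size = 256,
    q_len = 1,
)

# input_ids: shape (3, 1), one new token per sequence
# logits = model.forward_chunk(input_ids, cache = cache, attn_params = params)  # call may vary by version
```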
Now, there are some choices you could make about how to get to the above point in the first place. input_ids is still always a rectangular tensor, so to prefill the initial 10, 1025 and 320 tokens you'd need to do three forward passes to avoid padding: one with a shape of (3, 10), then another with shape (2, 310), and finally (1, 705).
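As a rough sketch of what those three passes would mean in terms of PagedParams (reusing block_index, device and the import from the snippet above; the commented forward call is again only illustrative):

```python
import torch

# (rows of block_index, tokens already in the cache, tokens in this pass)
passes = [
    (slice(0, 3),   0,  10),   # shape (3, 10):  first 10 tokens of all three sequences
    (slice(1, 3),  10, 310),   # shape (2, 310): tokens 10:320 of sequences b and c
    (slice(1, 2), 320, 705),   # shape (1, 705): tokens 320:1025 of sequence b
]

for rows, past_len, q_len in passes:
    bsz = rows.stop - rows.start
    params = ExLlamaV2Attention.PagedParams(
        batch_size = bsz,
        block_index = block_index[rows],
        cache_seqlens = torch.full((bsz,), past_len, dtype = torch.int32, device = device),
        max_cache_seqlen = past_len,
        page_size = 256,
        q_len = q_len,
    )
    # ids = the matching (bsz, q_len) slice of the stacked prompt tokens
    # model.forward_chunk(ids, cache = cache, attn_params = params)
```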
Or you just do each sequence in the batch as a bsz 1 forward pass. This is what the dynamic generator does, and it simplifies things a lot, especially for continuous batching. I.e., for the three prompts (with a code sketch after these examples):
prompt a:
batch_size: 1
block_index: [[0]]
cache_seqlens: [0]
q_len: 10
input_ids: tokenizer.encode(prompt_a)
prompt b:
batch_size: 1
block_index: [[2, 3, 4, 5, 6]]
cache_seqlens: [0]
q_len: 1025
input_ids: tokenizer.encode(prompt_b)
prompt c:
batch_size: 1
block_index: [[8, 9]]
cache_seqlens: [0]
q_len: 320
input_ids: tokenizer.encode(prompt_c)
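In code, that per-sequence prefill could be a simple loop over the prompts (a sketch reusing the names from the earlier snippets; prompt_a/b/c are the three prompt strings, and tokenizer.encode returns a (1, prompt_len) tensor):

```python
import torch

prompts = [prompt_a, prompt_b, prompt_c]
pages   = [[0], [2, 3, 4, 5, 6], [8, 9]]            # enough pages to hold each prompt

for prompt, page_list in zip(prompts, pages):
    ids = tokenizer.encode(prompt)                  # shape (1, q_len)
    params = ExLlamaV2Attention.PagedParams(
        batch_size = 1,
        block_index = torch.tensor([page_list], dtype = torch.int32, device = device),
        cache_seqlens = torch.tensor([0], dtype = torch.int32, device = device),
        max_cache_seqlen = 0,
        page_size = 256,
        q_len = ids.shape[-1],
    )
    # model.forward_chunk(ids, cache = cache, attn_params = params)
```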
There's a bunch of fun details about paged attention, such as the fact that the page indices don't need to be contiguous. Also they don't need to be unique, as long as you're not updating the same page twice in a forward pass. The dynamic generator uses both of those details for deduplication and continuous batching, respectively.
If you didn't want a predefined max_new_tokens length, you could allocate pages dynamically during inference. There's nothing that prevents you from adding page 13 after page 1 in the first sequence, or growing the block_index tensor by one column to add page 14 after page 7. It does of course require some bookkeeping in your generator, and I'm not sure how well that plays together with HF and pipelines and whatnot.
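For example, still assuming the block_index tensor from the earlier sketches, adding a page on the fly is just an index update, or a concat if you've run out of columns:

```python
import torch

# Sequence 0 crosses into a third page: reuse one of its padded slots
block_index[0, 2] = 13                     # page 13 now follows page 1 for sequence 0

# No free columns left: grow the block index by one column
pad = torch.zeros((block_index.shape[0], 1), dtype = block_index.dtype, device = block_index.device)
block_index = torch.cat([block_index, pad], dim = 1)
block_index[1, -1] = 14                    # page 14 now follows page 7 for sequence 1
```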
Okay, I kind of get the concept. I think I want to forward each sequence as a bsz 1 forward pass. Does this mean we have to loop over each sequence instead of doing one big batched forward pass? What about the cache instance? Should I make one for each sequence, or just one for all? But then how does the cache know which sequence is being forwarded with it?
You use one cache for everything, and it's the block_index tensor that says which pages in the cache are used for each sequence, whether you're doing them one at a time or batching.
One way to go about it would be to start by tokenizing all the prompts in a batch, then constructing the block index based on how many pages each sequence is going to need, including both the prompt and the completion:
block_index_batch:
[
[ 0, 1, 0, 0, 0, 0 ], # 10+500 tokens needs 2 pages
[ 2, 3, 4, 5, 6, 7 ], # 1025+500 tokens -> 6 pages
[ 8, 9, 10, 11, 0, 0 ] # 320+500 -> 4 pages
]
Then you run the three individual forward passes to prefill:
seq a: block_index = block_index_batch[0:1, :]
seq b: block_index = block_index_batch[1:2, :]
seq c: block_index = block_index_batch[2:3, :]
It doesn't matter if the block index has extra padding on the right, since it's indexed from the left. And then for each token you pass block_index_batch so you can index into all three sequences at once.
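So the token loop after prefill could look roughly like this (a sketch reusing block_index_batch and the other names from above; 500 is the max_new_tokens budget from the example, and the sampling step is only hinted at in comments):

```python
import torch

cache_seqlens = torch.tensor([10, 1025, 320], dtype = torch.int32, device = device)

for _ in range(500):
    params = ExLlamaV2Attention.PagedParams(
        batch_size = 3,
        block_index = block_index_batch,                    # all three rows at once
        cache_seqlens = cache_seqlens,
        max_cache_seqlen = cache_seqlens.max().item(),
        page_size = 256,
        q_len = 1,
    )
    # logits = model.forward_chunk(next_tokens, cache = cache, attn_params = params)
    # next_tokens = <sample from logits>, shape (3, 1)
    cache_seqlens += 1                                      # each sequence grew by one token
```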
I understand. But I have another doubt: what about the input mask and position offsets? The input mask is probably taken care of, since the masking is done inside flash attention. But what about position offsets?
You wouldn't use masking or position offsets in paged mode, only a list of sequence lengths, and then the flash-attn kernel handles the rest. This allows all sequences to start at position zero (as long as that corresponds to a page boundary in the cache, as determined by block_index) and have variable lengths as determined by cache_seqlens.
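To see why no mask or offsets are needed, here is a rough, self-contained sketch of the kind of call the paged forward pass ends up making, flash-attn's flash_attn_with_kvcache (argument names based on flash-attn 2.5+; shapes and dtypes here are just illustrative dummies, check your installed version):

```python
import torch
from flash_attn import flash_attn_with_kvcache

num_pages, page_size, n_heads, n_kv_heads, head_dim = 16, 256, 32, 8, 128
batch_size, q_len = 3, 1

q       = torch.randn(batch_size, q_len, n_heads, head_dim, dtype = torch.float16, device = "cuda")
k_new   = torch.randn(batch_size, q_len, n_kv_heads, head_dim, dtype = torch.float16, device = "cuda")
v_new   = torch.randn_like(k_new)
k_cache = torch.zeros(num_pages, page_size, n_kv_heads, head_dim, dtype = torch.float16, device = "cuda")
v_cache = torch.zeros_like(k_cache)

cache_seqlens = torch.tensor([10, 1025, 320], dtype = torch.int32, device = "cuda")
block_index   = torch.tensor([[0, 1,  0,  0,  0, 0],
                              [2, 3,  4,  5,  6, 7],
                              [8, 9, 10, 11, 12, 0]], dtype = torch.int32, device = "cuda")

out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    k = k_new, v = v_new,              # the new keys/values are written into the paged cache
    cache_seqlens = cache_seqlens,     # per-sequence past length: implies positions and the causal mask
    block_table = block_index,         # which cache pages belong to which sequence
    causal = True,
)
```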
So, I just created exllamav2 in HF format and it works well in batches. My code is in #606. Now I have a new problem: a bigger batch means bigger memory usage, and most of it goes to padding, especially when the sequences in a batch have very different lengths. Could you explain to me how exllamav2's paged attention works in code? I checked the code in exllamav2/model.py; PagedParams is used, but I don't know what to fill into its parameters.