cglagovichTT opened 1 month ago
Thank you for the diagram. I guess the other option for the KV cache layout is `[max_num_blocks, block_size, num_heads, head_dim]`. Is this worse? For attention, `[num_heads, block_size, head_dim]` would be more natural, since we need to do Q: `[block_size, head_dim]` @ K: `[head_dim, block_size]`.
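For intuition, the per-head score computation is a matmul over the trailing two dims, so keeping `block_size` and `head_dim` innermost per head lines up with the access pattern (illustrative NumPy, with made-up dimensions):

```python
import numpy as np

num_heads, block_size, head_dim = 8, 64, 128
Q = np.random.randn(num_heads, block_size, head_dim)
K = np.random.randn(num_heads, block_size, head_dim)

# Per head: [block_size, head_dim] @ [head_dim, block_size] -> [block_size, block_size]
scores = Q @ K.transpose(0, 2, 1)
```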
`paged_update_cache` is implemented on a branch and used as follows:

```python
cachett = ttl.operations.primary.transformers.paged_update_cache(
    cachett, xt, [], update_idxs_tensor=cache_idxs_tt, page_table=page_table_tt
)
```
More tasks:
Latest update: `paged_fill_cache` and `paged_update_cache` are finished. See their test cases for how we pass page tables and index tensors into these ops.

On the device side, we still need to add support for FlashDecode with a paged KV cache. This will be a matter of modifying the reader kernel to look like the `fill_cache` writer.

On the host side, we will need to write a simple test case which implements a paged allocator and an inference scheduler so we can demonstrate this functionality; a minimal allocator sketch is given below.
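For reference, the paged allocator in such a test could be as simple as a free list over physical blocks. This is a hypothetical sketch, not code from the branch; names like `PagedAllocator` are made up for illustration:

```python
class PagedAllocator:
    """Minimal free-list allocator over physical KV-cache blocks."""

    def __init__(self, max_num_blocks: int):
        self.free_blocks = list(range(max_num_blocks))
        self.page_tables = {}  # user_id -> [physical block numbers, in virtual order]

    def allocate_block(self, user_id) -> int:
        # Grab any free physical block and append it to the user's page table.
        block = self.free_blocks.pop()
        self.page_tables.setdefault(user_id, []).append(block)
        return block

    def free_user(self, user_id) -> None:
        # Return all of a finished user's blocks to the free list.
        self.free_blocks.extend(self.page_tables.pop(user_id, []))
```

An inference scheduler would call `allocate_block` whenever a user's sequence crosses a `block_size` boundary, and pass the resulting page table to the paged ops.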
This issue describes our proposal for a paged KV cache in tt-metal models. First we will describe how vLLM implements the paged KV cache, then we will describe how this might be adapted for metal. Actions generated from this issue include updating the `paged_fill_cache`, `paged_update_cache`, and `paged_scaled_dot_product_attention` ops.

## vLLM paged KV cache
Let's consider vLLM's `reshape_and_cache_kernel`, which updates the paged KV cache with new K and new V:
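(The original CUDA kernel is not reproduced here; below is a simplified NumPy sketch of its indexing logic, using the simplified `[num_blocks, num_heads, block_size, head_size]` cache shape discussed next rather than vLLM's actual memory layout.)

```python
import numpy as np

def reshape_and_cache_sketch(key, value, key_cache, value_cache, slot_mapping, block_size):
    # key/value:             [num_tokens, num_heads, head_size]
    # key_cache/value_cache: [num_blocks, num_heads, block_size, head_size]
    # slot_mapping:          [num_tokens], flat slot index in 0..num_blocks*block_size
    for token_idx, slot_idx in enumerate(slot_mapping):
        block_idx = slot_idx // block_size   # which physical block
        block_off = slot_idx % block_size    # offset within that block
        key_cache[block_idx, :, block_off, :] = key[token_idx]
        value_cache[block_idx, :, block_off, :] = value[token_idx]
```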
In this kernel, `slot_mapping` does the heavy lifting. Let's ignore the fact that `key_cache` and `value_cache` have strange shapes; this comes down to optimizing memory layout for the `paged_attention` CUDA kernel. Let's say `key_cache` has shape `[num_blocks, num_heads, block_size, head_size]`. The input `slot_mapping` is used to select the blocks, and the offsets into those blocks, where `key` and `value` should be inserted. `slot_mapping` is a flat mapping into `0..num_blocks*block_size` such that `slot_idx / block_size` is the block number and `slot_idx % block_size` is the offset. This is very simple in that it removes the need for a `block_tables` input.

Now let's look at how vLLM's attention kernel uses `block_tables`. `paged_attention_kernel` uses `block_tables` of shape `[num_seqs, max_num_blocks_per_seq]` to map pages to the caches.
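For example, a block table might look like this (made-up values; the padding convention here is illustrative, not necessarily vLLM's):

```python
# block_tables: [num_seqs, max_num_blocks_per_seq]
# Row i lists the physical blocks backing sequence i, in virtual-block order.
block_tables = [
    [3, 17, 9],    # seq 0: tokens 0..block_size-1 live in physical block 3, etc.
    [0, 5, -1],    # seq 1: only two blocks allocated so far
]
```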
## TT Paged KV

There will be some differences between our paged KV cache and vLLM's paged KV cache.
Let's say that our K and V caches will both be of shape `[max_num_blocks, num_heads, block_size, head_dim]`. The cache will be tilized so that it has `max_num_blocks * num_heads * block_size/32 * head_dim/32` tiles, so each block contains `num_heads * block_size/32 * head_dim/32` tiles. Our K and V caches can both adopt this layout.
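A quick back-of-the-envelope check with hypothetical dimensions (chosen only for illustration):

```python
TILE = 32
max_num_blocks, num_heads, block_size, head_dim = 1024, 8, 64, 128

tiles_per_block = num_heads * (block_size // TILE) * (head_dim // TILE)  # 8 * 2 * 4 = 64
total_tiles = max_num_blocks * tiles_per_block                           # 1024 * 64 = 65536
```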
Our `block_tables` will map from `max_num_blocks_per_seq -> max_num_blocks`, which gives a physical block number for a user's virtual block number. From this physical block number, a kernel can calculate tile indices into heads, block_size, and head_dim.

To simplify the host side, we can pass `block_tables` into `paged_update_cache`, `paged_fill_cache`, and `paged_sdpa`, instead of generating a `slot_table` like vLLM does for their update cache.

## paged_update_cache
The reader will do something like the following.
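(A hypothetical sketch of the index arithmetic only; the real reader kernel is device C++ operating on tile addresses, so treat the names and signature here as made up for illustration.)

```python
TILE = 32

def kv_tile_index(seq_pos, head, page_table, block_size, head_dim, num_heads):
    """First tile index touched when updating position `seq_pos` of head `head`.

    Assumes the cache layout [max_num_blocks, num_heads, block_size, head_dim],
    tilized row-major into 32x32 tiles, and a per-user page_table mapping
    virtual block number -> physical block number.
    """
    virtual_block = seq_pos // block_size        # which virtual block holds this position
    physical_block = page_table[virtual_block]   # translate through the page table
    row_in_block = seq_pos % block_size          # row offset within the block

    tiles_per_head = (block_size // TILE) * (head_dim // TILE)
    tiles_per_block = num_heads * tiles_per_head

    return (physical_block * tiles_per_block
            + head * tiles_per_head
            + (row_in_block // TILE) * (head_dim // TILE))
```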