tenstorrent / tt-metal

TT-NN operator library, and TT-Metalium low level kernel programming model.
Apache License 2.0

Paged KV cache support #10100

Open cglagovichTT opened 1 month ago

cglagovichTT commented 1 month ago

This issue describes our proposal for a paged KV cache in tt-metal models. First we describe how vLLM implements its paged KV cache, then how this might be adapted for tt-metal. Action items from this issue include updating the paged_fill_cache, paged_update_cache, and paged_scaled_dot_product_attention ops.

vLLM paged KV cache

Let's consider vLLM's reshape_and_cache_kernel, which updates the paged KV cache with new K and new V:

```cpp
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel(
    const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
    const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
    cache_t* __restrict__ key_cache,     // [num_blocks, num_heads, head_size/x,
                                         // block_size, x]
    cache_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size,
                                         // block_size]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int key_stride, const int value_stride, const int num_heads,
    const int head_size, const int block_size, const int x,
    const float kv_scale) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    // Padding token that should be ignored.
    return;
  }

  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;
...
```

In this kernel, slot_mapping does the heavy lifting. Let's ignore the unusual shapes of key_cache and value_cache; they come from optimizing memory layout for the paged_attention CUDA kernel. For simplicity, say key_cache has shape [num_blocks, num_heads, block_size, head_size]. The slot_mapping input selects the block, and the offset within that block, where each token's key and value are inserted. It is a flat index into 0..num_blocks*block_size-1: slot_idx / block_size is the block number and slot_idx % block_size is the offset within the block, while a negative slot marks a padding token that is skipped. This keeps the kernel simple because it needs no block_tables input: the host has already resolved each token's block when building slot_mapping.
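A minimal pure-Python reference of the slot_mapping arithmetic may help. It assumes the simplified key_cache shape [num_blocks, num_heads, block_size, head_size] described above and uses nested lists in place of device tensors. The name reshape_and_cache_ref is hypothetical; the sketch mirrors only the indexing of vLLM's kernel, not its memory layout, strides, or FP8 handling.

```python
def reshape_and_cache_ref(key, key_cache, slot_mapping, block_size):
    """Scatter new keys into a paged cache using a flat slot mapping (hypothetical reference).

    key:          [num_tokens][num_heads][head_size]
    key_cache:    [num_blocks][num_heads][block_size][head_size], modified in place
                  (simplified layout, not vLLM's real [.., head_size/x, block_size, x])
    slot_mapping: [num_tokens], entries in 0..num_blocks*block_size-1, or < 0 for padding
    """
    for token_idx, slot_idx in enumerate(slot_mapping):
        if slot_idx < 0:
            continue  # padding token, skipped just like in the CUDA kernel
        # Same as slot_idx / block_size and slot_idx % block_size for non-negative slots.
        block_idx, block_offset = divmod(slot_idx, block_size)
        for head, head_vec in enumerate(key[token_idx]):
            key_cache[block_idx][head][block_offset] = list(head_vec)
    return key_cache
```

For example, with block_size=4, slot 5 lands in block 1 at offset 1 and slot 11 in block 2 at offset 3.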

Now let's look at how vLLM's attention kernel uses block_tables.

```cpp
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
          bool IS_BLOCK_SPARSE,
          int PARTITION_SIZE = 0>  // Zero means no partitioning.
__device__ void paged_attention_kernel(
    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                     // max_num_partitions]
    scalar_t* __restrict__ out,  // [num_seqs, num_heads, max_num_partitions,
                                 // head_size]
    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                          // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                          // head_size, block_size]
    const int num_kv_heads,               // [num_heads]
    const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
    const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
    const int blocksparse_vert_stride, const int blocksparse_block_size,
    const int blocksparse_head_sliding_step)
```

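paged_attention_kernel takes block_tables with shape [num_seqs, max_num_blocks_per_seq], which maps each sequence's logical blocks to physical blocks in k_cache and v_cache; seq_lens tells the kernel how many tokens of each sequence are valid.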
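The role of block_tables in attention can be illustrated with a single-head, single-sequence decode reference in pure Python. This is a sketch under the same simplified [num_blocks, num_heads, block_size, head_size] cache layout as above, not vLLM's partitioned kernel, and paged_decode_attention_ref and gather_rows are hypothetical names. The only paging-specific part is the gather: logical position pos lives in physical block block_table[pos // block_size] at offset pos % block_size.

```python
import math


def gather_rows(cache, block_table, seq_len, block_size, head):
    """Collect one head's rows for logical positions 0..seq_len-1 from a paged cache."""
    rows = []
    for pos in range(seq_len):
        physical_block = block_table[pos // block_size]  # logical -> physical block
        rows.append(cache[physical_block][head][pos % block_size])
    return rows


def paged_decode_attention_ref(q, k_cache, v_cache, block_table, seq_len, block_size, head, scale):
    """softmax(scale * q . K^T) @ V for one query vector, reading K and V through block_table."""
    keys = gather_rows(k_cache, block_table, seq_len, block_size, head)
    values = gather_rows(v_cache, block_table, seq_len, block_size, head)
    logits = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(logits)  # subtract the max for a numerically stable softmax
    weights = [math.exp(x - m) for x in logits]
    total = sum(weights)
    head_size = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(head_size)]
```

Because the gather restores logical order, the result equals ordinary dense attention over the same tokens, regardless of which physical blocks the sequence was given.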

TT Paged KV

There will be some differences between our paged KV cache and vLLM's paged KV cache.

  1. K and V will share the same data layout. vLLM's distinct K and V layouts are motivated by CUDA kernel optimizations, so we would not benefit from adopting them.
  2. Our KV cache will be tilized (32x32 tiles), so all indexing must be done in units of tiles.

Let's say our K and V caches both have shape [max_num_blocks, num_heads, block_size, head_dim]. Once tilized, the cache holds max_num_blocks * num_heads * block_size/32 * head_dim/32 tiles, so each block contains num_heads * block_size/32 * head_dim/32 tiles. This requires block_size and head_dim to be multiples of 32.
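The tile arithmetic can be checked with a small helper, assuming 32x32 tiles (matching the /32 above). cache_tile_counts is a hypothetical name, and the numbers in the usage line are illustrative rather than a proposed configuration.

```python
TILE = 32  # assumed 32x32 tiles, matching the /32 in the text


def cache_tile_counts(max_num_blocks, num_heads, block_size, head_dim):
    """Return (tiles_per_block, total_tiles) for a tilized
    [max_num_blocks, num_heads, block_size, head_dim] cache."""
    if block_size % TILE or head_dim % TILE:
        raise ValueError("block_size and head_dim must be multiples of 32")
    tiles_per_block = num_heads * (block_size // TILE) * (head_dim // TILE)
    return tiles_per_block, max_num_blocks * tiles_per_block


print(cache_tile_counts(max_num_blocks=1024, num_heads=8, block_size=64, head_dim=128))
# (64, 65536)
```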

Our K and V caches can adopt the layout shown in the diagram:

[Diagram: proposed paged K/V cache layout]

Our block_tables, of shape [b, max_num_blocks_per_seq], will map each user's virtual block number (0..max_num_blocks_per_seq-1) to a physical block number (0..max_num_blocks-1). From the physical block number, a kernel can compute tile indices along the num_heads, block_size, and head_dim dimensions.
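A sketch of that calculation, assuming tiles are stored in row-major order over [block, head, block_size/32, head_dim/32], which is the ordering the block_start_id formula in the update-cache reader below implies. token_tile_ids is a hypothetical helper that returns the tile ids holding one token's row for one head.

```python
TILE = 32  # assumed 32x32 tiles


def token_tile_ids(block_table, token_idx, head, num_heads, block_size, head_dim):
    """Tile ids holding `token_idx` for one head, across all of head_dim.

    block_table: this user's row of block_tables (virtual block -> physical block).
    Assumes row-major tile order over [block, head, row_tile, col_tile].
    """
    block_size_tiles = block_size // TILE
    head_dim_tiles = head_dim // TILE
    physical_block = block_table[token_idx // block_size]  # virtual -> physical
    row_tile = (token_idx % block_size) // TILE           # tile row inside the block
    first = ((physical_block * num_heads + head) * block_size_tiles + row_tile) * head_dim_tiles
    return list(range(first, first + head_dim_tiles))
```

For example, token_tile_ids([5, 2], token_idx=70, head=1, num_heads=2, block_size=64, head_dim=64) returns [20, 21]: token 70 is in virtual block 1, which maps to physical block 2.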

To simplify the host side, we can pass block_tables directly into paged_update_cache, paged_fill_cache, and paged_sdpa, instead of generating a per-token slot mapping as vLLM does for its update-cache kernel.
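To illustrate the trade-off, here is a sketch of the vLLM-style slot mapping a host would otherwise rebuild every step, together with a check that it encodes the same (physical block, offset) pairs a paged op can derive on device from block_tables and the token index. make_slot_mapping and from_block_tables are hypothetical helpers.

```python
def make_slot_mapping(block_tables, token_idxs, block_size):
    """vLLM-style flat slots, one per user: physical_block * block_size + offset."""
    slots = []
    for user, pos in enumerate(token_idxs):
        physical_block = block_tables[user][pos // block_size]
        slots.append(physical_block * block_size + pos % block_size)
    return slots


def from_block_tables(block_tables, user, pos, block_size):
    """What a paged op can compute itself from block_tables and a token index."""
    return block_tables[user][pos // block_size], pos % block_size


block_tables = [[3, 0], [1, 2]]
token_idxs = [5, 2]
slots = make_slot_mapping(block_tables, token_idxs, block_size=4)
for user, (slot, pos) in enumerate(zip(slots, token_idxs)):
    # Both encodings identify the same (physical block, offset).
    assert divmod(slot, 4) == from_block_tables(block_tables, user, pos, 4)
```

The slot mapping changes at every decode step because token_idxs advance, whereas block_tables only changes when a user crosses into a newly allocated block.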

paged_update_cache

```python
# cache: Tensor [max_num_blocks, num_heads=1, block_size, head_dim]
# key: Tensor [1, b, 1[32], head_dim]  (1[32]: logical size 1, padded to 32)
# block_tables: Tensor [b, max_num_blocks_per_seq]
# token_idxs: List [b]
# pagetable_metadata: {num_heads, block_size, head_dim}
paged_update_cache(cache, key, token_idxs, block_tables, pagetable_metadata)
```
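For testing, the intended semantics of paged_update_cache can be pinned down with a pure-Python golden model on an untilized cache. This sketch assumes num_heads=1 as in the signature above, nested lists for tensors, and key passed as [b][head_dim] with the padded 1[32] dimension dropped; paged_update_cache_ref is a hypothetical name, not the op's API.

```python
def paged_update_cache_ref(cache, key, token_idxs, block_tables, block_size):
    """Write each user's new key row into its paged slot (hypothetical golden model).

    cache:        [max_num_blocks][1][block_size][head_dim], modified in place
    key:          [b][head_dim], one new row per user (padding dimension dropped)
    token_idxs:   [b], position of the new token for each user
    block_tables: [b][max_num_blocks_per_seq], virtual block -> physical block
    """
    for u, pos in enumerate(token_idxs):
        physical_block = block_tables[u][pos // block_size]
        cache[physical_block][0][pos % block_size] = list(key[u])  # head 0 only
    return cache
```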

The reader kernel will do something like the following (written as Python, with num_heads=1 as in the signature above):

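```python
TILE_HEIGHT = 32  # tiles are 32x32
TILE_WIDTH = 32

def read_update_tiles(token_idxs, block_tables, num_heads, block_size, head_dim,
                      cache, local_l1_ptr, read_tile):
    # read_tile(tile_id, cache, local_l1_ptr) stands in for the device-side tile read.
    block_size_tiles = block_size // TILE_HEIGHT
    head_dim_tiles = head_dim // TILE_WIDTH
    for u in range(len(token_idxs)):  # one user per batch entry
        virtual_block_idx = token_idxs[u] // block_size
        physical_block_idx = block_tables[u][virtual_block_idx]
        block_start_id = physical_block_idx * num_heads * block_size_tiles * head_dim_tiles

        # Tile row inside the block that holds this token (head 0, since num_heads=1).
        block_row_tile = (token_idxs[u] % block_size) // TILE_HEIGHT
        block_offset = block_row_tile * head_dim_tiles
        for w in range(head_dim_tiles):  # every tile across head_dim
            tile_id = block_start_id + block_offset + w
            read_tile(tile_id, cache, local_l1_ptr)
```
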
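After the reader brings in those head_dim_tiles tiles, the new key row still has to be placed at the right row inside each tile before the tiles are written back. A sketch of that step, assuming a read-modify-write update and an untilized view of each 32x32 tile as a list of rows (the on-device tilized face ordering is ignored); update_tile_row is a hypothetical helper.

```python
TILE = 32  # assumed 32x32 tiles


def update_tile_row(tiles, token_idx, block_size, new_row):
    """Overwrite one token's row in the head_dim_tiles tiles fetched by the reader.

    tiles:   list of head_dim // 32 tiles, each a 32x32 list of rows (untilized view, assumed)
    new_row: the token's key vector, length head_dim
    """
    row_in_tile = (token_idx % block_size) % TILE  # row within its 32-row tile
    for w, tile in enumerate(tiles):
        # Tile w covers head_dim columns [w*32, (w+1)*32).
        tile[row_in_tile] = list(new_row[w * TILE:(w + 1) * TILE])
    return tiles
```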
uaydonat commented 1 month ago

Thank you for the diagram.

I guess the other option for kv cache layout is [max_num_blocks, block_size, num_heads, head_dim]. Is this worse?

cglagovichTT commented 1 month ago

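For attention, [num_heads, block_size, head_dim] is more natural, since per head we need to compute Q [block_size, head_dim] @ K^T [head_dim, block_size], and in this layout each head's [block_size, head_dim] slice of a block is a contiguous run of tiles.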
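A rough count of the tiles needed to fetch one head's [block_size, head_dim] K slice under each candidate layout makes the difference concrete. This assumes the cache is tilized over its last two dimensions and that a partially used tile must still be read whole; tiles_for_head_slice and the layout labels are hypothetical. Under [max_num_blocks, block_size, num_heads, head_dim], the tilized dimensions are (num_heads, head_dim), so each tile holds one token for up to 32 heads.

```python
import math

TILE = 32  # assumed 32x32 tiles over the last two dimensions of the cache


def tiles_for_head_slice(layout, block_size, head_dim):
    """Tiles read to fetch one head's [block_size, head_dim] slice of one cache block.

    The count does not depend on num_heads in either layout.
    """
    head_dim_tiles = math.ceil(head_dim / TILE)
    if layout == "heads_first":   # [max_num_blocks, num_heads, block_size, head_dim]
        # The slice is a dense block of whole tiles.
        return math.ceil(block_size / TILE) * head_dim_tiles
    if layout == "tokens_first":  # [max_num_blocks, block_size, num_heads, head_dim]
        # Each token is a [num_heads, head_dim] matrix padded to tile height, and
        # the head's row lies in one tile row, so every token costs head_dim_tiles tiles.
        return block_size * head_dim_tiles
    raise ValueError(f"unknown layout: {layout}")


print(tiles_for_head_slice("heads_first", block_size=64, head_dim=128))   # 8
print(tiles_for_head_slice("tokens_first", block_size=64, head_dim=128))  # 256
```

Besides reading 32x more tiles, each tile in the second layout would carry only one useful row for the head, and those rows would have to be regathered into tiles before the Q @ K^T matmul.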

cglagovichTT commented 1 month ago

paged_update_cache is implemented on a branch and is used as follows:

```python
cachett = ttl.operations.primary.transformers.paged_update_cache(
    cachett, xt, [], update_idxs_tensor=cache_idxs_tt, page_table=page_table_tt
)
```

More tasks:

cglagovichTT commented 1 month ago

Latest update: paged_fill_cache and paged_update_cache are finished. See their test cases for how we pass page tables and index tensors into these ops.
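For comparison with paged_update_cache, a golden-model sketch of paged_fill_cache on an untilized cache is shown here. It assumes the op keeps fill_cache's role of writing a whole prefilled sequence for one user, with input of shape [num_heads][seq_len][head_dim] and the cache layout above. paged_fill_cache_ref and its argument order are hypothetical; the real signature is in the op's test cases.

```python
def paged_fill_cache_ref(cache, x, block_table, block_size):
    """Copy one user's prefill K (or V) into its physical blocks (hypothetical golden model).

    cache:       [max_num_blocks][num_heads][block_size][head_dim], modified in place
    x:           [num_heads][seq_len][head_dim] for this user
    block_table: this user's row of the page table, virtual block -> physical block
    """
    seq_len = len(x[0])
    for pos in range(seq_len):
        physical_block = block_table[pos // block_size]
        for head in range(len(x)):
            cache[physical_block][head][pos % block_size] = list(x[head][pos])
    return cache
```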

On the device side, we still need to add paged KV cache support to FlashDecode. This should mainly be a matter of modifying its reader kernel to look like the fill_cache writer, translating each virtual block through the page table before computing tile ids.
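A sketch of the K tile order such a reader might produce for one user and one KV head. It assumes row-major tile order over [block, head, block_size/32, head_dim/32] (consistent with the block_start_id formula in the update-cache reader), that the reader fetches every block up to the one containing cur_pos, and that positions beyond cur_pos are masked later. flash_decode_k_tiles is a hypothetical name, and splitting the work across cores is ignored.

```python
TILE = 32  # assumed 32x32 tiles


def flash_decode_k_tiles(block_table, cur_pos, kv_head, num_heads, block_size, head_dim):
    """Yield K tile ids covering positions 0..cur_pos, one virtual block at a time."""
    block_size_tiles = block_size // TILE
    head_dim_tiles = head_dim // TILE
    tiles_per_head = block_size_tiles * head_dim_tiles
    num_virtual_blocks = cur_pos // block_size + 1  # blocks holding positions 0..cur_pos
    for virtual_block in range(num_virtual_blocks):
        physical_block = block_table[virtual_block]  # page-table lookup
        start = (physical_block * num_heads + kv_head) * tiles_per_head
        # One head's [block_size, head_dim] slice is contiguous in this layout.
        # The last block may be only partly valid; masking is assumed to happen in compute.
        yield from range(start, start + tiles_per_head)
```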

On the host side, we will need to write up a simple test case which implements a paged allocator and an inference scheduler so we can demonstrate this functionality.
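A minimal sketch of a host-side paged allocator such a test might use: a free list of physical blocks, per-user block lists that grow as tokens are generated, and a padded page table in the [b, max_num_blocks_per_seq] shape the ops take. The class and method names are hypothetical, padding unused entries with 0 assumes the ops never read past a user's current position, and the scheduler is reduced to "allocate before each step, free when a user finishes".

```python
from typing import Dict, List


class PagedAllocator:
    """Hands out physical KV-cache blocks to users on demand (hypothetical sketch)."""

    def __init__(self, max_num_blocks: int, block_size: int, max_num_blocks_per_seq: int):
        self.block_size = block_size
        self.max_num_blocks_per_seq = max_num_blocks_per_seq
        self.free_blocks: List[int] = list(range(max_num_blocks))
        self.user_blocks: Dict[int, List[int]] = {}

    def ensure_capacity(self, user: int, num_tokens: int) -> None:
        """Allocate blocks so `user` can hold positions 0..num_tokens-1."""
        blocks = self.user_blocks.setdefault(user, [])
        needed = -(-num_tokens // self.block_size)  # ceil division
        if needed > self.max_num_blocks_per_seq:
            raise ValueError("sequence exceeds max_num_blocks_per_seq")
        while len(blocks) < needed:
            if not self.free_blocks:
                raise MemoryError("out of KV-cache blocks")
            blocks.append(self.free_blocks.pop())

    def free(self, user: int) -> None:
        """Return all of a finished user's blocks to the free list."""
        self.free_blocks.extend(self.user_blocks.pop(user, []))

    def page_table(self, users: List[int]) -> List[List[int]]:
        """Rows of physical block ids, padded with 0 (assumed never read) to full width."""
        table = []
        for user in users:
            blocks = self.user_blocks.get(user, [])
            table.append(blocks + [0] * (self.max_num_blocks_per_seq - len(blocks)))
        return table
```

Before each decode step, the scheduler would call ensure_capacity(user, token_idx + 1) for every active user and only needs to rebuild the page table when a new block was handed out.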