google / jetstream-pytorch

PyTorch/XLA integration with JetStream (https://github.com/google/JetStream) for LLM inference
Apache License 2.0

Add page attention manager and kvcache manager #166

Closed FanhaiLu1 closed 1 month ago

FanhaiLu1 commented 1 month ago

This PR adds two classes that are fundamental to page attention in JetStream:

PageAttentionManager:

This class allocates and frees page resources, computes page metadata, and supports cache insertion.

PageKVCacheGenerate:

This class updates decode caches in a page-attention format. Unlike the standard LLM KV cache shape ([batch_size, num_heads, seq_len, head_dim]), PageKVCache uses the shape [num_heads, total_num_pages, page_size, head_dim].
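To make the layout concrete, here is a minimal NumPy sketch of how a paged KV cache update could work. All names (`insert_token`, `page_table`, `seq_lens`) are illustrative assumptions for this sketch, not the PR's actual API: each sequence owns a list of pages, and each new decode token is written into the next free slot of its current page.

```python
import numpy as np

# Illustrative dimensions for the paged layout
# [num_heads, total_num_pages, page_size, head_dim].
num_heads, total_num_pages, page_size, head_dim = 2, 8, 4, 16

# Paged decode key cache (a value cache would be shaped the same way).
k_cache = np.zeros((num_heads, total_num_pages, page_size, head_dim))

# Hypothetical per-sequence bookkeeping: sequence i owns the pages in
# page_table[i], and seq_lens[i] counts tokens already written.
page_table = {0: [3, 5], 1: [0]}
seq_lens = {0: 6, 1: 2}

def insert_token(cache, seq_id, new_k):
    """Write one decode step's key vectors ([num_heads, head_dim])
    into the next free slot of the sequence's current page."""
    pos = seq_lens[seq_id]
    page = page_table[seq_id][pos // page_size]   # which physical page
    slot = pos % page_size                        # offset within the page
    cache[:, page, slot, :] = new_k
    seq_lens[seq_id] = pos + 1

new_k = np.ones((num_heads, head_dim))
insert_token(k_cache, 0, new_k)
# Token 6 of sequence 0 lands in its second page (page 5), slot 2.
assert k_cache[:, 5, 2, :].sum() == num_heads * head_dim
```

The point of the indirection through the page table is that a sequence's cache no longer needs to be contiguous in memory, so pages can be allocated and freed independently as sequences grow and finish.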