AI-Hypercomputer / jetstream-pytorch

PyTorch/XLA integration with JetStream (https://github.com/google/JetStream) for LLM inference
Apache License 2.0

Add page attention manager and kvcache manager #167

Closed FanhaiLu1 closed 3 months ago

FanhaiLu1 commented 3 months ago

This PR adds two classes that are fundamental to page attention in JetStream:

PageAttentionManager:

This class manages and frees page resources, calculates page metadata, and supports cache insertion.
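To illustrate the manage/free responsibilities described above, here is a minimal sketch of a page-resource manager. The class name, method names, and bookkeeping structures below are assumptions for illustration only, not the PR's actual API:

```python
# Hypothetical sketch of page-resource management; names and behavior
# are assumptions, not the actual PageAttentionManager API.

class SimplePageManager:
    """Tracks which fixed-size pages are free and which belong to each sequence."""

    def __init__(self, total_num_pages: int, page_size: int):
        self.page_size = page_size
        self.free_pages = list(range(total_num_pages))  # indices of unused pages
        self.pages_per_seq: dict[int, list[int]] = {}   # seq id -> allocated pages

    def reserve(self, seq_id: int, num_tokens: int) -> list[int]:
        """Allocate enough pages to hold num_tokens cache entries."""
        needed = -(-num_tokens // self.page_size)  # ceiling division
        if needed > len(self.free_pages):
            raise RuntimeError("out of cache pages")
        pages = [self.free_pages.pop() for _ in range(needed)]
        self.pages_per_seq.setdefault(seq_id, []).extend(pages)
        return pages

    def free(self, seq_id: int) -> None:
        """Return a finished sequence's pages to the free pool."""
        self.free_pages.extend(self.pages_per_seq.pop(seq_id, []))


mgr = SimplePageManager(total_num_pages=8, page_size=4)
pages = mgr.reserve(seq_id=0, num_tokens=10)  # 10 tokens need 3 pages of size 4
print(len(pages), len(mgr.free_pages))        # 3 5
mgr.free(0)
print(len(mgr.free_pages))                    # 8
```

The page metadata the real manager computes (e.g. which page and slot each token maps to) follows from the same page-size arithmetic shown in `reserve`.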

PageKVCacheGenerate:

This class updates decode caches in a page-attention format. Unlike the standard LLM KV cache shape ([batch_size, num_heads, seq_len, head_dim]), the paged cache uses the shape [num_heads, total_num_pages, page_size, head_dim]: instead of one contiguous per-sequence slab, all sequences share a pool of fixed-size pages.
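The shape difference can be made concrete with a toy NumPy example. The numbers below are illustrative assumptions; the contiguous page layout here is a simplification, since in a real paged cache a sequence's pages need not be adjacent:

```python
# Compare the standard KV cache layout with a paged layout, shapes only.
import numpy as np

batch_size, num_heads, seq_len, head_dim = 2, 4, 16, 8
page_size = 4  # illustrative choice

# Standard KV cache: one contiguous [seq_len] slab per (batch, head).
standard = np.zeros((batch_size, num_heads, seq_len, head_dim))

# Paged KV cache: a shared pool of fixed-size pages. Sizing the pool to
# exactly cover both sequences for this toy comparison.
total_num_pages = batch_size * (seq_len // page_size)
paged = np.zeros((num_heads, total_num_pages, page_size, head_dim))

print(standard.shape)               # (2, 4, 16, 8)
print(paged.shape)                  # (4, 8, 4, 8)
print(standard.size == paged.size)  # True: same capacity, different layout
```

The paged layout lets the manager hand out pages on demand, so memory is bounded by tokens actually generated rather than by a preallocated maximum sequence length per request.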