AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Add support for local sliding window attention in TPU splash_attention #830

Closed: gagika closed this 1 month ago

gagika commented 1 month ago

Added support for sliding window attention masking in TPU splash_attention
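To make the masking rule concrete, here is a dense reference sketch in plain JAX of what a local sliding window mask computes. It assumes causal attention with a window of `window_size` keys that includes the current token, so query `i` sees keys `j` with `i - window_size < j <= i`. It does not call the splash_attention kernel or its mask classes. `local_sliding_window_mask` and `reference_local_attention` are hypothetical names used only for illustration, and the sketch covers a single unbatched head.

```python
import jax
import jax.numpy as jnp


def local_sliding_window_mask(q_len: int, kv_len: int, window_size: int) -> jax.Array:
    """Boolean (q_len, kv_len) mask; True where query i may attend to key j.

    Hypothetical helper: causal masking combined with a local window, so key j
    is visible to query i only when i - window_size < j <= i.
    """
    q_idx = jnp.arange(q_len)[:, None]    # shape (q_len, 1)
    kv_idx = jnp.arange(kv_len)[None, :]  # shape (1, kv_len)
    causal = kv_idx <= q_idx                   # no attention to future tokens
    in_window = kv_idx > q_idx - window_size   # drop keys older than the window
    return causal & in_window


def reference_local_attention(q, k, v, window_size: int) -> jax.Array:
    """Dense O(L^2) attention with the sliding window mask applied.

    Hypothetical reference only; q, k, v have shape (seq_len, head_dim).
    """
    scores = (q @ k.T) * (q.shape[-1] ** -0.5)
    mask = local_sliding_window_mask(q.shape[0], k.shape[0], window_size)
    # Use the dtype's most negative finite value instead of -inf so that
    # masked entries get zero weight after softmax without producing NaNs.
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v


mask = local_sliding_window_mask(5, 5, window_size=2)
print(mask.astype(jnp.int32))
```

With `window_size=2`, each row of the printed mask keeps only the current token and the one immediately before it. A kernel such as splash_attention can exploit this banded structure by skipping key blocks that fall entirely outside the window, instead of materializing the full dense mask as this sketch does.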