AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Adds ragged attention. #835

Closed patemotter closed 2 months ago

patemotter commented 2 months ago

Adds a Ragged Attention Pallas kernel as an option when performing autoregressive attention. It is disabled by default and does not interfere with the existing AR attention path; it can be enabled with the CLI argument `use_ragged_attention=true`.
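The core idea behind ragged attention during autoregressive decoding is that each sequence in the batch only attends over its own true number of cached keys/values, rather than the full padded cache length. A minimal NumPy sketch of that idea (not the actual Pallas kernel, which fuses this into a TPU kernel; function and variable names here are illustrative):

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over the last axis.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def ragged_decode_attention(q, k, v, lengths):
    """Single-step attention where batch row b only attends to its
    first lengths[b] cached keys/values.

    q: [B, D] query for the current decode step
    k, v: [B, S, D] padded KV cache
    lengths: [B] true sequence length per batch row
    """
    B, S, D = k.shape
    scores = np.einsum("bd,bsd->bs", q, k) / np.sqrt(D)
    # Mask out cache slots beyond each row's true length.
    mask = np.arange(S)[None, :] < lengths[:, None]  # [B, S]
    scores = np.where(mask, scores, -np.inf)
    weights = softmax(scores)
    return np.einsum("bs,bsd->bd", weights, v)
```

The payoff of the real kernel is that it can skip work on the masked region entirely, whereas a dense kernel computes scores for every padded slot and then discards them.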

Improvements are most noticeable when