ROCm / AMDMIGraphX

AMD's graph optimization engine.
https://rocm.docs.amd.com/projects/AMDMIGraphX/en/latest/
MIT License

PagedAttention support in MIGraphX #3588

Open gyulaz-htec opened 1 week ago

gyulaz-htec commented 1 week ago

The MLPerf team is interested in MIGraphX for Llama 2 inference. Currently we're using vLLM for Llama 2 inference, which uses PagedAttention (PA) and continuous batching to achieve better performance than the current static-batching implementations (Hugging Face Optimum, TensorRT). However, vLLM is written in Python, and that is an overhead for us. We would like to use a lower-level API, and MIGraphX could satisfy that requirement. We're aware that the MIGraphX team is currently working on enabling Llama 2 with GQA on a quantized model. We would like to know:

- Are there any plans to implement PA, or should the current GQA support be comparable to it in terms of performance?
- Would PA fit into the MIGraphX feature set?
- If yes, can someone give an estimate of how much work adding PA to the project would be?
- Are there any blockers to implementing PA?
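For context on what PA would require, here is a minimal sketch of the core PagedAttention idea: the KV cache is split into fixed-size physical blocks, and each sequence keeps a block table mapping logical token positions to blocks, so memory is allocated on demand and reclaimed as soon as a sequence finishes (which is what makes continuous batching cheap). This is an illustrative toy, not MIGraphX or vLLM API; the names `PagedKVCache`, `BLOCK_SIZE`, etc. are hypothetical.

```python
# Toy model of PagedAttention's KV-cache bookkeeping (illustrative only;
# not MIGraphX or vLLM API). Real implementations store K/V tensors in the
# blocks and index them from the attention kernel via the block table.
from typing import Dict, List, Tuple

BLOCK_SIZE = 4  # tokens per physical block (vLLM's default is 16)


class PagedKVCache:
    """Maps each sequence's logical token positions to fixed-size physical
    blocks via a per-sequence block table, instead of reserving contiguous
    memory for the maximum sequence length up front."""

    def __init__(self, num_blocks: int):
        self.free_blocks: List[int] = list(range(num_blocks))
        self.block_tables: Dict[int, List[int]] = {}  # seq_id -> physical block ids
        self.seq_lens: Dict[int, int] = {}

    def append_token(self, seq_id: int) -> Tuple[int, int]:
        """Reserve a cache slot for one new token of `seq_id`;
        returns (physical_block_id, offset_within_block)."""
        pos = self.seq_lens.get(seq_id, 0)
        table = self.block_tables.setdefault(seq_id, [])
        if pos % BLOCK_SIZE == 0:  # current block is full (or first token)
            table.append(self.free_blocks.pop())
        self.seq_lens[seq_id] = pos + 1
        return table[pos // BLOCK_SIZE], pos % BLOCK_SIZE

    def free_sequence(self, seq_id: int) -> None:
        """Return a finished sequence's blocks to the pool, freeing space
        for new requests (the basis of continuous batching)."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.seq_lens.pop(seq_id, None)
```

The point of the sketch is that the attention kernel must gather K/V through this indirection (block table plus offset) rather than reading a contiguous buffer, which is the main change PA would impose on MIGraphX's attention implementation.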

We could use this issue to discuss the possibility of PA in MIGraphX and to track the different opinions and ideas. cc @causten @TedThemistokleous @pfultz2 @turneram @attila-dusnoki-htec @ototh-htec