microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

WIP: FlashAttention for WebGPU EP #22919

Open sushraja-msft opened 9 hours ago

sushraja-msft commented 9 hours ago

WIP: An implementation of FlashAttention that currently works for MHA (multi-head attention).

The other scenarios require more debugging. The algorithm also needs optimization for the sequence-length-1 case, because the way ComputeDotProduct is invoked leaves workgroups unused.
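For readers unfamiliar with the technique, here is a minimal single-head sketch in plain Python of the tiled, online-softmax accumulation that FlashAttention is built on. It is not the WGSL shader code from this PR. The function names (`naive_attention`, `flash_attention`) and the `block_size` parameter are hypothetical, chosen only for illustration.

```python
import math

def naive_attention(q, k, v):
    """Reference: softmax(q k^T / sqrt(d)) v for one head, full score row at once."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total
                    for c in range(len(v[0]))])
    return out

def flash_attention(q, k, v, block_size=2):
    """Same result, but K/V are visited in blocks and the softmax is
    accumulated online, so the full score row is never materialized.
    (Illustrative sketch only; block_size is a hypothetical parameter.)"""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        running_max = -math.inf   # largest score seen so far
        running_sum = 0.0         # softmax denominator, relative to running_max
        acc = [0.0] * len(v[0])   # unnormalized output accumulator
        for start in range(0, len(k), block_size):
            k_blk = k[start:start + block_size]
            v_blk = v[start:start + block_size]
            scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k_blk]
            new_max = max(running_max, max(scores))
            # Rescale earlier partial results to the new maximum.
            correction = math.exp(running_max - new_max)
            weights = [math.exp(s - new_max) for s in scores]
            running_sum = running_sum * correction + sum(weights)
            acc = [a * correction + sum(w * vj[c] for w, vj in zip(weights, v_blk))
                   for c, a in enumerate(acc)]
            running_max = new_max
        out.append([a / running_sum for a in acc])
    return out
```

In this sketch, a query with sequence length 1 runs the outer loop only once, so all the work lies in the loop over K/V blocks. If a GPU dispatch assigns workgroups per query row or query tile, that shape would offer little parallelism. That may be related to the unused workgroups mentioned above, but the description does not say how ComputeDotProduct is dispatched, so treat this as an assumption.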