WIP: Implementation of FlashAttention that works for MHA
Currently works only on machines where the subgroup size equals the tile size (e.g. Intel GPUs).
Works only when the new sequence length is 1.
The other scenarios require more debugging. The algorithm also needs optimization for the sequence-length-1 case, because workgroups are left unused in the way ComputeDotProduct is invoked; a sketch of the issue is below.
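A minimal sketch of the underutilization, assuming ComputeDotProduct is dispatched with one workgroup per tile of the new sequence length per head. The helper and its parameter names (`new_seq_len`, `num_heads`, `tile_size`) are hypothetical and only illustrate the dispatch math, not the actual implementation:

```cpp
#include <cstdint>
#include <iostream>

// Hypothetical dispatch-size calculation for illustration only.
struct DispatchDims {
  uint32_t x;  // tiles along the new sequence length
  uint32_t y;  // one workgroup per attention head
};

DispatchDims ComputeDotProductDispatch(uint32_t new_seq_len, uint32_t num_heads,
                                       uint32_t tile_size) {
  // When new_seq_len == 1, x collapses to a single tile, so only num_heads
  // workgroups are launched in total and the rest of the GPU sits idle.
  uint32_t seq_tiles = (new_seq_len + tile_size - 1) / tile_size;
  return {seq_tiles, num_heads};
}

int main() {
  DispatchDims d = ComputeDotProductDispatch(/*new_seq_len=*/1, /*num_heads=*/32,
                                             /*tile_size=*/16);
  std::cout << "workgroups launched: " << d.x * d.y << "\n";  // 32, underutilized
}
```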