qedawkins closed this issue 2 months ago.
@qedawkins @mattwalsh @powderluv I think we have a lot of follow-up here; should we update?
@harsh-nod FYI, so you can update on what has been done here.
At a very high level, we have favorable performance numbers on this compared to Triton and are currently in the process of upstreaming patches into both IREE and MLIR. Happy to go into more detail if needed.
@nicolasvasilache for visibility
We have flash attention support now.
This issue is intended to offer a starting point for discussion on how to implement FlashAttention in IREE, as well as outline the lowering steps and give an IR example.
Brief Algorithm Overview
FlashAttention is an optimization for attention blocks that, at its core, is the fusion of matrix multiply + softmax + matrix multiply, i.e.

$$\mathrm{AttentionHead} = \mathrm{softmax}(QK^T)\,V$$

for query, key, and value matrices $Q$, $K$, $V$. The fused block can also include layers such as an optional dropout between the softmax and the second matmul. The algorithm for FlashAttention is described in its introductory paper (with an optional block-sparse matrix for an approximate version of the algorithm).

Lowering outline & example IR
Enabling FlashAttention can be partitioned into the work required before and after dispatch formation. For the backend, we need matmul + softmax + matmul fused into a single dispatch, ideally in as simple a form as possible. This matters in particular because the algorithm's tiling changes how the softmax is computed: each softmax row is accumulated tile by tile instead of over the full row at once. For the formation of the dispatch region, an outline of the lowering for attention blocks follows. Start with a PyTorch module containing a single MultiheadAttention layer.
(The import example here uses a combination of torch_fx + torch-mlir, via Shark, to get the Linalg IR.) Looking at the IR we get back, the attention computation coming from PyTorch is decomposed into BMM + Softmax + BMM.
Currently, softmax is then further decomposed when lowering to Linalg.
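For reference, the BMM + Softmax + BMM decomposition can be mirrored in a minimal pure-Python sketch, using nested lists in place of tensors. It assumes softmax is lowered to the usual numerically stable sequence of a max reduction, subtract, exp, sum reduction, and divide. That sequence is an assumption about the shape of the IR, not a transcription of the imported Linalg. The names `bmm`, `softmax_decomposed`, and `attention_head` are hypothetical.

```python
import math


def bmm(a, b):
    """Batched matmul: a is [B][M][K], b is [B][K][N] -> [B][M][N]."""
    return [[[sum(x[k] * y[k][n] for k in range(len(y))) for n in range(len(y[0]))]
             for x in xa]
            for xa, y in zip(a, b)]


def softmax_decomposed(s):
    """Softmax over the last dimension, split into the separate reduction and
    elementwise stages it is assumed to become after lowering."""
    row_max = [[max(row) for row in mat] for mat in s]                      # reduction: max
    shifted = [[[x - mx for x in row] for row, mx in zip(mat, ms)]
               for mat, ms in zip(s, row_max)]                               # elementwise: sub
    exps = [[[math.exp(x) for x in row] for row in mat] for mat in shifted]  # elementwise: exp
    row_sum = [[sum(row) for row in mat] for mat in exps]                    # reduction: sum
    return [[[x / sm for x in row] for row, sm in zip(mat, ss)]
            for mat, ss in zip(exps, row_sum)]                               # elementwise: div


def attention_head(q, k, v):
    """BMM + softmax + BMM, i.e. softmax(Q K^T) V for each batch entry."""
    k_t = [[list(col) for col in zip(*mat)] for mat in k]  # per-batch transpose of K
    return bmm(softmax_decomposed(bmm(q, k_t)), v)
```

Each of the five softmax stages ends up as its own op, so a fusion has to recognize the whole group as one softmax before it can treat the chain as attention.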
This poses a problem for implementing FlashAttention, as it would currently require inferring the softmax in addition to identifying and fusing the surrounding matrix multiplications. The backend would then need to interpret the fused dispatch as FlashAttention in the same way, which is made harder by the optional dropout, masking, and scaling not shown in the IR above. Because the softmax should be visible coming from the frontend, having a named softmax op is one way to ease the challenge of identifying an attention block, but it still requires some form of specialized fusion.
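To make concrete what the backend has to produce from the fused dispatch, here is a minimal pure-Python sketch of a FlashAttention-style forward pass, checked against the unfused computation. Keys and values are visited in tiles, so a full row of scores is never materialized. Each query row instead carries a running max, a running denominator, and an unnormalized output accumulator, which are rescaled whenever the max grows (the "online softmax"). The sketch assumes no scaling, masking, or dropout. It loops over queries outermost for readability, whereas the paper's loop order and memory blocking differ. The names `naive_attention` and `tiled_attention` are hypothetical.

```python
import math


def naive_attention(Q, K, V):
    """Reference softmax(Q K^T) V, computed with the full score row at once."""
    out = []
    for q in Q:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in K]
        m = max(scores)  # subtract the row max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j] for w, v in zip(weights, V)) / total
                    for j in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, block=2):
    """Online-softmax forward pass over key/value tiles of `block` rows."""
    d_v = len(V[0])
    out = []
    for q in Q:
        m = -math.inf     # running max of scores seen so far
        denom = 0.0       # running softmax denominator, relative to m
        acc = [0.0] * d_v  # running unnormalized output, relative to m
        for start in range(0, len(K), block):
            k_tile = K[start:start + block]
            v_tile = V[start:start + block]
            scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in k_tile]
            m_new = max(m, max(scores))
            scale = math.exp(m - m_new)  # rescales old partial results to m_new
            p = [math.exp(s - m_new) for s in scores]
            denom = denom * scale + sum(p)
            acc = [a * scale + sum(pi * v[j] for pi, v in zip(p, v_tile))
                   for j, a in enumerate(acc)]
            m = m_new
        out.append([a / denom for a in acc])  # normalize once, after the last tile
    return out


Q = [[1.0, 0.0], [0.5, -1.0]]
K = [[1.0, 2.0], [0.0, 1.0], [-1.0, 0.5]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 3.0]]
```

Because the max and denominator are only final after the last tile, normalization moves to the end. This is why the softmax cannot stay a standalone op inside the fused dispatch.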
Another potential solution is to add attention as a LinalgExt op. Note, however, that some models define their own attention blocks (e.g. HuggingFace BERT, CompVis Stable Diffusion), which makes it difficult to rely on seeing an incoming attention op, even though MultiheadAttention and its internal op _native_multi_head_attention exist (and PyTorch decomposes these in forward passes by default anyway). Moreover, to my knowledge there is no MHLO op for attention. Assuming we won't get an incoming op, we still need the specialized fusion in addition to a regressive lowering (i.e. raising the decomposed IR) back to the attention op. This may still be a good way to prototype the backend work.
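The specialized fusion plus raising could look roughly like the following minimal sketch over a hypothetical SSA-style toy IR. The `Op` class, the op kinds, and the `attention` op it emits are placeholders. They are not MLIR or IREE APIs, nor the signature of any actual LinalgExt op. The sketch matches matmul -> max/sub/exp/sum/div -> matmul and replaces the chain with a single op. It requires each intermediate to have exactly its expected uses, so that deleting it is safe.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    """One op in a hypothetical toy IR (not MLIR's data structures)."""
    result: str      # name of the value this op defines
    kind: str        # "matmul", "reduce_max", "sub", "exp", "reduce_sum", "div", ...
    operands: tuple  # names of the values this op reads


def raise_attention(ops):
    """Replace matmul -> decomposed softmax -> matmul with one 'attention' op."""
    defs = {op.result: op for op in ops}
    uses = {}
    for op in ops:
        for name in op.operands:
            uses[name] = uses.get(name, 0) + 1

    def match(name, kind, operands=None):
        """Return the op defining `name` if it has the given kind (and operands)."""
        op = defs.get(name)
        if op is None or op.kind != kind:
            return None
        if operands is not None and op.operands != operands:
            return None
        return op

    dead, replaced = set(), {}
    for mm2 in ops:
        if mm2.kind != "matmul":
            continue
        div = match(mm2.operands[0], "div")
        if div is None:
            continue
        exp = match(div.operands[0], "exp")
        total = match(div.operands[1], "reduce_sum", (div.operands[0],))
        if exp is None or total is None:
            continue
        sub = match(exp.operands[0], "sub")
        if sub is None:
            continue
        scores, row_max = sub.operands
        mx = match(row_max, "reduce_max", (scores,))
        mm1 = match(scores, "matmul")
        if mx is None or mm1 is None:
            continue
        chain = [mm1, mx, sub, exp, total, div]
        # The scores feed reduce_max and sub; exp feeds reduce_sum and div.
        expected = dict(zip((op.result for op in chain), (2, 1, 1, 2, 1, 1)))
        if any(uses.get(name) != n for name, n in expected.items()):
            continue
        q, k_t = mm1.operands
        replaced[mm2.result] = Op(mm2.result, "attention", (q, k_t, mm2.operands[1]))
        dead.update(op.result for op in chain)
    return [replaced.get(op.result, op) for op in ops if op.result not in dead]


ops = [
    Op("s", "matmul", ("q", "kT")),
    Op("m", "reduce_max", ("s",)),
    Op("d", "sub", ("s", "m")),
    Op("e", "exp", ("d",)),
    Op("z", "reduce_sum", ("e",)),
    Op("p", "div", ("e", "z")),
    Op("o", "matmul", ("p", "v")),
]
```

A real pattern would also have to check reduction dimensions and shapes and accept the scaling, masking, and dropout variants, which is where most of the difficulty lies.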
Goals + Tasks
This is a WIP section that can be updated based on discussion.
Additional Resources
Here are a few resources I found useful while compiling this issue.