iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/

Enable fusion of attention blocks (i.e. FlashAttention) #11780

Closed: qedawkins closed this issue 2 months ago

qedawkins commented 1 year ago

This issue is intended to offer a starting point for discussion on how to implement FlashAttention in IREE, as well as outline the lowering steps and give an IR example.

Brief Algorithm Overview

FlashAttention is an optimization for attention blocks that, at its core, fuses matrix multiply + softmax + matrix multiply (i.e. AttentionHead = Softmax(QK^T)V for query, key, and value matrices Q, K, V). This can also include optional layers such as dropout between the softmax and the second matmul. The algorithm for FlashAttention as described in its introductory paper is shown below (with an optional block-sparse matrix for an approximate version of the algorithm).

[Figure: FlashAttention algorithm (and its block-sparse variant) from the FlashAttention paper]
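For intuition, here is a minimal Python sketch of the core idea: tile over K/V blocks and maintain a running (online) softmax so the full score matrix is never materialized. This is an illustration only, not IREE code; the single-head (seq, head_dim) layout, the 1/sqrt(d) scaling, and the tile size are assumptions, and dropout/masking and the block-sparse variant are omitted.

```python
import torch

def flash_attention_forward(q, k, v, block_size=4):
    """Minimal sketch of a tiled attention forward pass with an online softmax,
    loosely following the FlashAttention idea. q/k/v: (seq_len, head_dim)."""
    seq_len, head_dim = q.shape
    out = torch.zeros_like(q)                          # running (unnormalized) output
    row_max = torch.full((seq_len, 1), float("-inf"))  # running row max
    row_sum = torch.zeros(seq_len, 1)                  # running softmax denominator

    for start in range(0, seq_len, block_size):        # iterate over K/V tiles
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        scores = (q @ k_blk.T) / head_dim ** 0.5       # (seq_len, block)

        new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
        p = torch.exp(scores - new_max)                # unnormalized tile probabilities
        rescale = torch.exp(row_max - new_max)         # correct previously accumulated tiles
        row_sum = row_sum * rescale + p.sum(dim=-1, keepdim=True)
        out = out * rescale + p @ v_blk
        row_max = new_max

    return out / row_sum                               # equals softmax(QK^T / sqrt(d)) @ V
```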

Lowering outline & example IR

Enabling FlashAttention can be partitioned into the work required before and after dispatch formation. For the backend, we need matmul + softmax + matmul fused into a single dispatch, ideally in as simple a format as possible, in particular because the computation of the softmax in FlashAttention is affected by the tiling done in the algorithm. For the formation of the dispatch region, an outline of the lowering for attention blocks is shown here. Start with a PyTorch module containing a single MultiheadAttention layer.

import torch

embed_dim = 256
num_heads = 8   # head_dim = embed_dim // num_heads = 32
batch_size = 8
sequence_length = 20
q = torch.randn(batch_size, sequence_length, embed_dim)
k = torch.randn(batch_size, sequence_length, embed_dim)
v = torch.randn(batch_size, sequence_length, embed_dim)

class ExampleAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attention = torch.nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, query, key, value):
        return self.attention(query, key, value)

model = ExampleAttention()
from shark.shark_importer import import_with_fx

mlir_module, func_name = import_with_fx(model, (q,k,v))

(The importing example here uses a combination of torch.fx + torch-mlir to get the Linalg IR via SHARK.) Looking at the IR we get back, the attention computation coming from PyTorch gets decomposed into BMM + Softmax + BMM:

    %61 = torch.aten.bmm %56, %60 : !torch.tensor, !torch.tensor -> !torch.tensor
    %62 = torch.prim.ListConstruct %int20, %int8, %int8, %int8 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %63 = torch.aten._unsafe_view %61, %62 : !torch.tensor, !torch.list<int> -> !torch.tensor
    %64 = torch.aten._softmax %63, %int-1, %false : !torch.tensor, !torch.int, !torch.bool -> !torch.tensor
    %65 = torch.prim.ListConstruct %int20, %int8, %int8, %int8 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %66 = torch.aten.expand %64, %65, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor
    %67 = torch.prim.ListConstruct %int160, %int8, %int8 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %68 = torch.aten.view %66, %67 : !torch.tensor, !torch.list<int> -> !torch.tensor
    %69 = torch.prim.ListConstruct %int20, %int8, %int8, %int32 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %70 = torch.aten.expand %49, %69, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor
    %71 = torch.prim.ListConstruct %int160, %int8, %int32 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %72 = torch.aten.view %70, %71 : !torch.tensor, !torch.list<int> -> !torch.tensor
    %73 = torch.aten.bmm %68, %72 : !torch.tensor, !torch.tensor -> !torch.tensor
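For reference, the decomposition above corresponds roughly to the following PyTorch computation. This is a sketch only; the shapes mirror the IR (with the default batch_first=False layout, the (8, 20, 256) inputs are treated as (seq=8, batch=20, embed=256), giving batch*num_heads = 20*8 = 160 and head_dim = 32), and the input/output projections of MultiheadAttention are omitted.

```python
import torch

# Shapes copied from the IR above: batch*num_heads = 160, sequence length 8,
# head_dim = 32. The 1/sqrt(head_dim) scaling is folded into the omitted
# projections, which is why it does not appear between these ops.
q = torch.randn(160, 8, 32)
k = torch.randn(160, 8, 32)
v = torch.randn(160, 8, 32)

scores = torch.bmm(q, k.transpose(1, 2))   # first bmm  -> (160, 8, 8)
probs = torch.softmax(scores, dim=-1)      # softmax over the last dimension
out = torch.bmm(probs, v)                  # second bmm -> (160, 8, 32)
```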

Currently, the softmax then gets further decomposed when lowering to Linalg:

#map5 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map7 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>
module {
  func.func @attention(%arg0: tensor<160x8x32xf32>, %arg1: tensor<160x32x8xf32>, %arg2: tensor<160x8x32xf32>) -> tensor<160x8x32xf32> {
    %c0_i64 = arith.constant 0 : i64
    %cst_3 = arith.constant 0.000000e+00 : f32
    %cst_4 = arith.constant -3.40282347E+38 : f32
    %21 = tensor.empty() : tensor<160x8x8xf32>
    %22 = linalg.fill ins(%cst_3 : f32) outs(%21 : tensor<160x8x8xf32>) -> tensor<160x8x8xf32>
    %23 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<160x8x32xf32>, tensor<160x32x8xf32>) outs(%22 : tensor<160x8x8xf32>) -> tensor<160x8x8xf32>
    %expanded_26 = tensor.expand_shape %23 [[0, 1], [2], [3]] : tensor<160x8x8xf32> into tensor<20x8x8x8xf32>
    %24 = tensor.empty() : tensor<20x8x8x1xi64>
    %25 = linalg.fill ins(%c0_i64 : i64) outs(%24 : tensor<20x8x8x1xi64>) -> tensor<20x8x8x1xi64>
    %26 = tensor.empty() : tensor<20x8x8x1xf32>
    %27 = linalg.fill ins(%cst_4 : f32) outs(%26 : tensor<20x8x8x1xf32>) -> tensor<20x8x8x1xf32>
    %28:2 = linalg.generic {indexing_maps = [#map5, #map7, #map7], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%expanded_26 : tensor<20x8x8x8xf32>) outs(%27, %25 : tensor<20x8x8x1xf32>, tensor<20x8x8x1xi64>) {
    ^bb0(%in: f32, %out: f32, %out_31: i64):
      %46 = linalg.index 3 : index
      %47 = arith.index_cast %46 : index to i64
      %48 = arith.maxf %in, %out : f32
      %49 = arith.cmpf ogt, %in, %out : f32
      %50 = arith.select %49, %47, %out_31 : i64
      linalg.yield %48, %50 : f32, i64
    } -> (tensor<20x8x8x1xf32>, tensor<20x8x8x1xi64>)
    %29 = tensor.empty() : tensor<20x8x8x8xf32>
    %30 = linalg.generic {indexing_maps = [#map5, #map7, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_26, %28#0 : tensor<20x8x8x8xf32>, tensor<20x8x8x1xf32>) outs(%29 : tensor<20x8x8x8xf32>) {
    ^bb0(%in: f32, %in_31: f32, %out: f32):
      %46 = arith.subf %in, %in_31 : f32
      linalg.yield %46 : f32
    } -> tensor<20x8x8x8xf32>
    %31 = linalg.generic {indexing_maps = [#map5, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%30 : tensor<20x8x8x8xf32>) outs(%29 : tensor<20x8x8x8xf32>) {
    ^bb0(%in: f32, %out: f32):
      %46 = math.exp %in : f32
      linalg.yield %46 : f32
    } -> tensor<20x8x8x8xf32>
    %32 = linalg.fill ins(%cst_3 : f32) outs(%26 : tensor<20x8x8x1xf32>) -> tensor<20x8x8x1xf32>
    %33 = linalg.generic {indexing_maps = [#map5, #map7], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%31 : tensor<20x8x8x8xf32>) outs(%32 : tensor<20x8x8x1xf32>) {
    ^bb0(%in: f32, %out: f32):
      %46 = arith.addf %in, %out : f32
      linalg.yield %46 : f32
    } -> tensor<20x8x8x1xf32>
    %34 = linalg.generic {indexing_maps = [#map5, #map7, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%31, %33 : tensor<20x8x8x8xf32>, tensor<20x8x8x1xf32>) outs(%29 : tensor<20x8x8x8xf32>) {
    ^bb0(%in: f32, %in_31: f32, %out: f32):
      %46 = arith.divf %in, %in_31 : f32
      linalg.yield %46 : f32
    } -> tensor<20x8x8x8xf32>
    %collapsed_27 = tensor.collapse_shape %34 [[0, 1], [2], [3]] : tensor<20x8x8x8xf32> into tensor<160x8x8xf32>
    %35 = tensor.empty() : tensor<160x8x32xf32>
    %36 = linalg.fill ins(%cst_3 : f32) outs(%35 : tensor<160x8x32xf32>) -> tensor<160x8x32xf32>
    %37 = linalg.batch_matmul ins(%collapsed_27, %arg2 : tensor<160x8x8xf32>, tensor<160x8x32xf32>) outs(%35 : tensor<160x8x32xf32>) -> tensor<160x8x32xf32>
    return %37 : tensor<160x8x32xf32>
  }
}
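In plain terms, the chain of linalg ops above implements the numerically stable softmax as max / subtract / exp / sum / divide (the argmax computed alongside the max in %28 is unused). A rough Python equivalent, with shapes copied from the IR:

```python
import torch

x = torch.randn(20, 8, 8, 8)                    # expanded scores (%expanded_26)

row_max = x.max(dim=-1, keepdim=True).values    # max reduction      (%28)
shifted = x - row_max                           # subtract the max   (%30)
numer = torch.exp(shifted)                      # exponentiate       (%31)
denom = numer.sum(dim=-1, keepdim=True)         # sum reduction      (%33)
probs = numer / denom                           # normalize          (%34)
```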

This poses a problem for implementing FlashAttention, as it currently would require inferring the softmax in addition to identifying and fusing the surrounding matrix multiplications. The backend would then need to be able to similarly interpret the fused dispatch as flash attention (exacerbated by the optional dropout/masking/scaling not shown in the above IR). Because we should see the softmax coming from the frontend, having a named softmax op is one way to alleviate the challenge of identifying an attention block, but it still requires some form of specialized fusion.
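For concreteness, a hedged sketch of where those optional pieces sit relative to the softmax (the function name, arguments, and defaults are illustrative, not a real IREE or PyTorch interface):

```python
import torch

def attention_variant(q, k, v, mask=None, dropout_p=0.0, scale=None):
    # Scaling is applied to the scores before the softmax.
    scale = q.shape[-1] ** -0.5 if scale is None else scale
    scores = torch.bmm(q, k.transpose(1, 2)) * scale
    # Masking typically shows up as an additive bias (e.g. -inf at masked positions).
    if mask is not None:
        scores = scores + mask
    probs = torch.softmax(scores, dim=-1)
    # Dropout, when present, is applied to the probabilities before the second matmul.
    probs = torch.nn.functional.dropout(probs, p=dropout_p)
    return torch.bmm(probs, v)
```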

Another potential solution is to add attention as a LinalgExt op. Note, however, that some models define their own attention blocks (e.g. HuggingFace BERT, CompVis Stable Diffusion), which makes it difficult to rely on seeing an incoming attention op, even though something like MultiheadAttention and the internal op _native_multi_head_attention exist (and PyTorch decomposes the latter in forward passes by default anyway). Moreover, to my knowledge there is no mhlo op for attention. Assuming we won't get an incoming op, we still end up needing the specialized fusion in addition to a regressive lowering (effectively raising the decomposed IR back up) to the attention op. This may still be a good way to prototype for the backend work.

Goals + Tasks

This is a WIP section that can be updated based on discussion.

Additional Resources

Here are a few resources I found useful while compiling this issue.

allieculp commented 1 year ago

@qedawkins @mattwalsh @powderluv I think we have a lot of follow up here - should we update?

powderluv commented 1 year ago

@harsh-nod fyi to update what is done here.

harsh-nod commented 1 year ago

At a very high level, we have favorable performance numbers on this vs Triton and are currently in the process of upstreaming patches into both IREE and MLIR. Happy to get into more details if required.

allieculp commented 1 year ago

@nicolasvasilache for visibility

antiagainst commented 2 months ago

We have flash attention support now.