plaidml / tpp-mlir

TPP experimentation on MLIR for linear algebra
https://arxiv.org/abs/2404.15204

Enable running softmax with TPPs #584

Open chelini opened 1 year ago

chelini commented 1 year ago

To enable running softmax with TPPs we need more operations:

  1. max/sum reduce ops (%2 and %8)
  2. sub op (%5). We need to support broadcast semantics in sub, or implement an explicit broadcast op (see %4). The same holds for div, whose divisor is broadcast in %9.
  3. exp (%6)
  4. div (%10)
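For reference, the four op classes above can be modeled as 2-D kernels that each operate on one T×T slice and are called one after another. This is a minimal NumPy sketch, not the TPP API: all function names are hypothetical and NumPy stands in for the real kernels. It also shows the two options for item 2: a sub/div that broadcasts a per-row vector itself, or an explicit broadcast op followed by a plain element-wise op.

```python
import numpy as np

# Hypothetical 2-D kernels standing in for the TPP ops listed above.
# Each works on one T x T tile (rows = query positions, cols = key positions).

def reduce_max_rows(x):
    # max reduce over the last dim, as in %2: (T, T) -> (T,)
    return x.max(axis=1)

def reduce_sum_rows(x):
    # sum reduce over the last dim, as in %8: (T, T) -> (T,)
    return x.sum(axis=1)

def sub_row_bcast(x, v):
    # option 1: sub with broadcast semantics (%5); v[i] is subtracted from every x[i, j]
    return x - v[:, None]

def broadcast_rows(v, cols):
    # option 2: explicit broadcast op (like %4) that materializes a (T,) vector
    # as a (T, cols) tile, so the following sub/div need no broadcast support
    return np.broadcast_to(v[:, None], (v.shape[0], cols))

def exp_elementwise(x):
    # element-wise exp (%6)
    return np.exp(x)

def div_row_bcast(x, v):
    # div with broadcast semantics (%10); every x[i, j] is divided by v[i]
    return x / v[:, None]

def softmax_2d(x):
    # call the kernels one after another, in the same order as the IR
    m = reduce_max_rows(x)
    e = exp_elementwise(sub_row_bcast(x, m))
    return div_row_bcast(e, reduce_sum_rows(e))
```

Either broadcast option yields the same values; the choice only decides whether broadcast lives inside the binary op or in a separate op.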

The IR below shows a softmax example in Linalg, extracted from a self-attention layer. The lowering is: TF dialect -> StableHLO -> Linalg IR. To lower from the TF dialect to StableHLO we use tf-opt; from StableHLO to Linalg we use the IREE compiler and print the IR after the iree-stablehlo-to-iree-input pass.

The dimensions of %arg0 are [B, heads, T, T], where B is the batch dimension, heads is the number of heads, and T is the sequence length.
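Written out per element, the IR below computes the following, with x = %arg0, indices b, h, i following the [B, heads, T, T] layout, and j ranging over the last (reduction) dimension. The mapping to SSA values is a reading of the IR: %5 is the difference x - m and %6 its exp. Subtracting the row maximum m does not change the result, because the factor exp(-m) cancels between numerator and denominator, but it keeps exp from overflowing.

```latex
\begin{aligned}
m_{bhi}  &= \max_{j} x_{bhij} && \text{(\%2)} \\
s_{bhi}  &= \sum_{j} \exp\!\left(x_{bhij} - m_{bhi}\right) && \text{(\%8)} \\
y_{bhij} &= \frac{\exp\!\left(x_{bhij} - m_{bhi}\right)}{s_{bhi}} && \text{(\%10)}
\end{aligned}
```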

```mlir
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>
module {
  func.func @softmax(%arg0: tensor<64x8x5x5xf32>) -> tensor<64x8x5x5xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %cst_0 = arith.constant 0xFF800000 : f32
    %0 = tensor.empty() : tensor<64x8x5xf32>
    %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<64x8x5xf32>) -> tensor<64x8x5xf32>
    %2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0 : tensor<64x8x5x5xf32>) outs(%1 : tensor<64x8x5xf32>) {
    ^bb0(%in: f32, %out: f32):
      %11 = arith.maxf %out, %in : f32
      linalg.yield %11 : f32
    } -> tensor<64x8x5xf32>
    %expanded = tensor.expand_shape %2 [[0], [1], [2, 3]] : tensor<64x8x5xf32> into tensor<64x8x5x1xf32>
    %3 = tensor.empty() : tensor<64x8x5x5xf32>
    %4 = linalg.generic {indexing_maps = [#map2, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded : tensor<64x8x5x1xf32>) outs(%3 : tensor<64x8x5x5xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<64x8x5x5xf32>
    %5 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %4 : tensor<64x8x5x5xf32>, tensor<64x8x5x5xf32>) outs(%3 : tensor<64x8x5x5xf32>) {
    ^bb0(%in: f32, %in_2: f32, %out: f32):
      %11 = arith.subf %in, %in_2 : f32
      linalg.yield %11 : f32
    } -> tensor<64x8x5x5xf32>
    %6 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%5 : tensor<64x8x5x5xf32>) outs(%3 : tensor<64x8x5x5xf32>) {
    ^bb0(%in: f32, %out: f32):
      %11 = math.exp %in : f32
      linalg.yield %11 : f32
    } -> tensor<64x8x5x5xf32>
    %7 = linalg.fill ins(%cst : f32) outs(%0 : tensor<64x8x5xf32>) -> tensor<64x8x5xf32>
    %8 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%6 : tensor<64x8x5x5xf32>) outs(%7 : tensor<64x8x5xf32>) {
    ^bb0(%in: f32, %out: f32):
      %11 = arith.addf %out, %in : f32
      linalg.yield %11 : f32
    } -> tensor<64x8x5xf32>
    %expanded_1 = tensor.expand_shape %8 [[0], [1], [2, 3]] : tensor<64x8x5xf32> into tensor<64x8x5x1xf32>
    %9 = linalg.generic {indexing_maps = [#map2, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_1 : tensor<64x8x5x1xf32>) outs(%3 : tensor<64x8x5x5xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<64x8x5x5xf32>
    %10 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%6, %9 : tensor<64x8x5x5xf32>, tensor<64x8x5x5xf32>) outs(%3 : tensor<64x8x5x5xf32>) {
    ^bb0(%in: f32, %in_2: f32, %out: f32):
      %11 = arith.divf %in, %in_2 : f32
      linalg.yield %11 : f32
    } -> tensor<64x8x5x5xf32>
    return %10 : tensor<64x8x5x5xf32>
  }
}
```
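To make the data flow of this IR easier to follow, here is a NumPy model that reproduces it op by op, with one variable per SSA value. It is a reference sketch for checking shapes and semantics, not part of the lowering; the function name is hypothetical. Note that %4 and %9 are the explicit broadcasts mentioned in item 2: each turns a [B, heads, T, 1] tensor into [B, heads, T, T] through the constant 0 in #map2.

```python
import numpy as np

def softmax_like_ir(arg0):
    # Mirrors the Linalg IR step by step; arg0 has shape [B, heads, T, T].
    B, H, T, _ = arg0.shape
    # %1: fill with -inf (0xFF800000 is the f32 bit pattern of -inf)
    init_max = np.full((B, H, T), -np.inf, dtype=arg0.dtype)
    # %2: max reduction over the last ("reduction") dimension
    v2 = np.maximum(init_max, arg0.max(axis=3))
    # %expanded: [B, H, T] -> [B, H, T, 1]
    expanded = v2.reshape(B, H, T, 1)
    # %4: broadcast via #map2 = (d0, d1, d2, d3) -> (d0, d1, d2, 0)
    v4 = np.broadcast_to(expanded, (B, H, T, T))
    # %5: element-wise sub of two same-shaped tensors
    v5 = arg0 - v4
    # %6: element-wise exp
    v6 = np.exp(v5)
    # %7 / %8: fill with 0, then sum reduction over the last dimension
    v8 = np.zeros((B, H, T), dtype=arg0.dtype) + v6.sum(axis=3)
    # %expanded_1 / %9: expand and broadcast the row sums
    v9 = np.broadcast_to(v8.reshape(B, H, T, 1), (B, H, T, T))
    # %10: element-wise div
    return v6 / v9
```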

Related to #414.

TBD:

rengolin commented 1 year ago

This IR looks similar to what we generate from mlir-gen, but the shapes are odd (5x5). Where do they come from?

The MHA softmax is a little different; we need to cover both styles.

The other option is to use Linalg fusion on tensors, but restrict the pass to the softmax operations only, to avoid having to split the body of the fused generic later on before mapping to TPPs.

Softmax in libxsmm is lowered as an equation, and simply calling the kernels one after another is very close to optimal. I would not build complicated machinery specific to certain complex patterns unless the benefit were very large and there were no other way.

Softmax will eventually be lowered as an equation, which is the right approach long term, so for now we can settle for most of the performance and get the rest later.

chelini commented 1 year ago

Yes, calling the kernels one after the other is the plan. Still, we must either fuse along the 64 and 8 dimensions to extract 2-D tensors, or materialize the two outermost dimensions as loops for each Linalg op and replace the body with a TPP operation. Do you have an example of the IR generated by mlir-gen? The 5 is an arbitrary sequence length; it does not matter in this context.
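Either way, the end state is a loop over the 64 × 8 outer iterations in which every kernel call sees a single T×T (here 5x5) slice, matching the 2-D operands the comment above says the TPP mapping needs. The NumPy sketch below models only that loop structure, under that assumption; the helper names are hypothetical and the per-slice body stands in for the sequence of TPP calls.

```python
import numpy as np

def softmax_2d(x):
    # Per-slice kernel sequence: max reduce, sub, exp, sum reduce, div.
    m = x.max(axis=1, keepdims=True)
    e = np.exp(x - m)
    return e / e.sum(axis=1, keepdims=True)

def softmax_outer_loops(arg0):
    # Materialize the two outermost dims (64 and 8 in the IR) as loops,
    # so each iteration operates on one 2-D T x T slice.
    out = np.empty_like(arg0)
    B, H = arg0.shape[:2]
    for b in range(B):
        for h in range(H):
            out[b, h] = softmax_2d(arg0[b, h])
    return out
```

The loop form produces the same result as the fully vectorized softmax over the last axis; it only changes how the work is partitioned.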

rengolin commented 1 year ago

> Do you have an example of the IR generated by mlir-gen?

Yup. Just run mlir-gen and you'll see.

Also, just to be clear, this is really low priority. Finding the right shapes for MHA and finally getting TPP on tensors in the main pipeline are still the most important tasks right now.