[DT][Fusion] Move set encoding after forming dispatch region #17718

Open · hanhanW opened this issue 2 weeks ago

hanhanW commented 2 weeks ago

Learning from past experience and the needs of quantization, we want to explore moving the SetEncoding pass to right after flow dispatch formation. At that point we know what the actual dispatch looks like, e.g., the dequant/broadcast ops and their consumers are already formed into a matmul dispatch, and we can set encodings based on those dispatches. The encoding now carries padding semantics, so it also provides a chance to clean up the padding and query_upper_tile_size tech debt. This will also enable fusion for data-tiling, e.g., fusion with mmt4d on the CPU side.

TBD: Add subtasks to the issue; there is a fair amount of work involved.

hanhanW commented 1 week ago

I wrote down what should happen step by step here, using the MLIR input below as an example:

util.func public @broadcasting_dequant_op(%arg0 : tensor<?x?xi8>, %rhs : tensor<?x?x?xi32>) -> tensor<?x?x?xi32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d1 = tensor.dim %arg0, %c0 : tensor<?x?xi8>
  %d2 = tensor.dim %arg0, %c1 : tensor<?x?xi8>
  %d0 = tensor.dim %rhs, %c0 : tensor<?x?x?xi32>
  %empty = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xi32>
  %dequant = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%arg0 : tensor<?x?xi8>) outs(%empty : tensor<?x?x?xi32>) {
    ^bb0(%in: i8, %out: i32):
      %12 = arith.extui %in : i8 to i32
      linalg.yield %12 : i32
    } -> tensor<?x?x?xi32>

  %B = tensor.dim %dequant, %c0 : tensor<?x?x?xi32>
  %M = tensor.dim %dequant, %c1 : tensor<?x?x?xi32>
  %N = tensor.dim %rhs, %c1 : tensor<?x?x?xi32>
  %init = tensor.empty(%B, %M, %N) : tensor<?x?x?xi32>
  %cst = arith.constant 0 : i32
  %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<?x?x?xi32>) -> tensor<?x?x?xi32>
  %op = linalg.batch_matmul_transpose_b
      ins(%dequant, %rhs : tensor<?x?x?xi32>, tensor<?x?x?xi32>)
      outs(%fill : tensor<?x?x?xi32>) -> tensor<?x?x?xi32>
  util.return %op : tensor<?x?x?xi32>
}

After FormDispatchRegionsPass, the dequant ops are still outside the dispatch region:

util.func public @broadcasting_dequant_op(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @broadcasting_dequant_op(%input0: tensor<?x?xi8>, %input1: tensor<?x?x?xi32>) -> (%output0: tensor<?x?x?xi32>)"}} {
  %c0_i32 = arith.constant 0 : i32
  %0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
  %1 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[1] : index
  %2 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<?x?xi8>{%0, %1}
  %3 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[0] : index
  %4 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[1] : index
  %5 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[2] : index
  %6 = hal.tensor.import %arg1 "input1" : !hal.buffer_view -> tensor<?x?x?xi32>{%3, %4, %5}
  %7 = tensor.empty(%3, %0, %1) : tensor<?x?x?xi32>
  %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%2 : tensor<?x?xi8>) outs(%7 : tensor<?x?x?xi32>) {
  ^bb0(%in: i8, %out: i32):
    %13 = arith.extui %in : i8 to i32
    linalg.yield %13 : i32
  } -> tensor<?x?x?xi32>
  %9 = tensor.empty(%3, %0, %4) : tensor<?x?x?xi32>
  %10 = linalg.fill ins(%c0_i32 : i32) outs(%9 : tensor<?x?x?xi32>) -> tensor<?x?x?xi32>
  %11 = flow.dispatch.region -> (tensor<?x?x?xi32>{%3, %0, %4}) {
    %13 = linalg.batch_matmul_transpose_b ins(%8, %6 : tensor<?x?x?xi32>, tensor<?x?x?xi32>) outs(%10 : tensor<?x?x?xi32>) -> tensor<?x?x?xi32>
    flow.return %13 : tensor<?x?x?xi32>
  }
  %12 = hal.tensor.export %11 "output0" : tensor<?x?x?xi32>{%3, %0, %4} -> !hal.buffer_view
  util.return %12 : !hal.buffer_view
}

Then we run CloneProducersIntoDispatchRegionsPass, which clones the dequant ops into the dispatch region. The copy left outside is converted into its own dispatch region, but it will be killed later; let's ignore it for now. So the IR would be:

#map = affine_map<(d0, d1, d2) -> (d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
  util.func public @broadcasting_dequant_op(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @broadcasting_dequant_op(%input0: tensor<?x?xi8>, %input1: tensor<?x?x?xi32>) -> (%output0: tensor<?x?x?xi32>)"}} {
    %c0_i32 = arith.constant 0 : i32
    %0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
    %1 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[1] : index
    %2 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<?x?xi8>{%0, %1}
    %3 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[0] : index
    %4 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[1] : index
    %5 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[2] : index
    %6 = hal.tensor.import %arg1 "input1" : !hal.buffer_view -> tensor<?x?x?xi32>{%3, %4, %5}
    %7 = flow.dispatch.region -> (tensor<?x?x?xi32>{%3, %0, %4}) {
      %9 = tensor.empty(%3, %0, %4) : tensor<?x?x?xi32>
      %c0_i32_0 = arith.constant 0 : i32
      %10 = tensor.empty(%3, %0, %1) : tensor<?x?x?xi32>
      %11 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%2 : tensor<?x?xi8>) outs(%10 : tensor<?x?x?xi32>) {
      ^bb0(%in: i8, %out: i32):
        %14 = arith.extui %in : i8 to i32
        linalg.yield %14 : i32
      } -> tensor<?x?x?xi32>
      %12 = linalg.fill ins(%c0_i32_0 : i32) outs(%9 : tensor<?x?x?xi32>) -> tensor<?x?x?xi32>
      %13 = linalg.batch_matmul_transpose_b ins(%11, %6 : tensor<?x?x?xi32>, tensor<?x?x?xi32>) outs(%12 : tensor<?x?x?xi32>) -> tensor<?x?x?xi32>
      flow.return %13 : tensor<?x?x?xi32>
    }
    %8 = hal.tensor.export %7 "output0" : tensor<?x?x?xi32>{%3, %0, %4} -> !hal.buffer_view
    util.return %8 : !hal.buffer_view
  }
}

I think this is where SetEncoding should happen. The result of SetEncoding would look something like below. We look through the producers of the matmul input operands and set encodings on their sources outside the dispatch, and we set encodings on the inits within the dispatch region. The set_encoding ops left outside the dispatch region will be handled by a later pass (e.g., "CloneEncodingsIntoProducerPass").

For the init part, empty->set_encoding will be folded into a tensor.empty that carries the encoding, and empty->fill->set_encoding will be folded to empty->fill by the same pattern.
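
A minimal sketch of those two folds, with the encoding attribute abbreviated as #enc (the attribute contents are elided here):

// Before the folds: encodings are set on the raw inits.
%empty = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xi32>
%enc_empty = iree_encoding.set_encoding %empty : tensor<?x?x?xi32> -> tensor<?x?x?xi32, #enc>
%fill = linalg.fill ins(%c0_i32 : i32) outs(%empty : tensor<?x?x?xi32>) -> tensor<?x?x?xi32>
%enc_fill = iree_encoding.set_encoding %fill : tensor<?x?x?xi32> -> tensor<?x?x?xi32, #enc>

// After the folds: the encoding lives on the tensor.empty directly,
// and the fill writes straight into the encoded init.
%enc_empty = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xi32, #enc>
%enc_fill = linalg.fill ins(%c0_i32 : i32) outs(%enc_empty : tensor<?x?x?xi32, #enc>) -> tensor<?x?x?xi32, #enc>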

util.func public @broadcasting_dequant_op(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @broadcasting_dequant_op(%input0: tensor<?x?xi8>, %input1: tensor<?x?x?xi32>) -> (%output0: tensor<?x?x?xi32>)"}} {
  %c0_i32 = arith.constant 0 : i32
  %0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
  %1 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[1] : index
  %2 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<?x?xi8>{%0, %1}
  %3 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[0] : index
  %4 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[1] : index
  %5 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[2] : index
  %6 = hal.tensor.import %arg1 "input1" : !hal.buffer_view -> tensor<?x?x?xi32>{%3, %4, %5}
  // We'll need role/indexOperand, element_types, original_type?, and indexing maps
  // in the encoding.
  %2_with_encoding = iree_encoding.set_encoding %2
    : tensor<?x?xi8> -> tensor<?x?xi8, #iree_encoding.encoding<...>>
  %6_with_encoding = iree_encoding.set_encoding %6
    : tensor<?x?x?xi32> -> tensor<?x?x?xi32, #iree_encoding.encoding<...>>

  %10 = flow.dispatch.region -> (tensor<?x?x?xi32>{%3, %0, %4}) {
    %12 = tensor.empty(%3, %0, %4) : tensor<?x?x?xi32>
    %12_with_encoding = iree_encoding.set_encoding %12
      : tensor<?x?x?xi32> -> tensor<?x?x?xi32, #iree_encoding.encoding<...>>
    %c0_i32_0 = arith.constant 0 : i32
    %13 = tensor.empty(%3, %0, %1) : tensor<?x?x?xi32>
    // the role is LHS
    %13_with_encoding = iree_encoding.set_encoding %13
      : tensor<?x?x?xi32> -> tensor<?x?x?xi32, #iree_encoding.encoding<...>>
    %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%2_with_encoding : tensor<?x?xi8, #iree_encoding.encoding<...>>)
      outs(%13_with_encoding : tensor<?x?x?xi32, #iree_encoding.encoding<...>>) {
    ^bb0(%in: i8, %out: i32):
      %17 = arith.extui %in : i8 to i32
      linalg.yield %17 : i32
    } -> tensor<?x?x?xi32, #iree_encoding.encoding<...>>
    %15 = linalg.fill ins(%c0_i32_0 : i32) outs(%12_with_encoding : tensor<?x?x?xi32, #iree_encoding.encoding<...>>) -> tensor<?x?x?xi32, #iree_encoding.encoding<...>>
    %15_with_encoding = iree_encoding.set_encoding %15
      : tensor<?x?x?xi32> -> tensor<?x?x?xi32, #iree_encoding.encoding<...>>

    %16 = linalg.batch_matmul_transpose_b
      ins(%14, %6_with_encoding : tensor<?x?x?xi32, #iree_encoding.encoding<...>>, tensor<?x?x?xi32, #iree_encoding.encoding<...>>)
      outs(%15_with_encoding : tensor<?x?x?xi32, #iree_encoding.encoding<...>>) -> tensor<?x?x?xi32, #iree_encoding.encoding<...>>

    %16_unset_encoding = iree_encoding.unset_encoding %16 ...
    flow.return %16_unset_encoding : tensor<?x?x?xi32>
  }
  %11 = hal.tensor.export %10 "output0" : tensor<?x?x?xi32>{%3, %0, %4} -> !hal.buffer_view
  util.return %11 : !hal.buffer_view
}

Then we implement a CloneEncodingsIntoProducerPass, which either moves the set_encoding op into the producer dispatch region or wraps it in a new dispatch region in place (a rough sketch of the wrap-in-place option follows the list below). The reason I don't move the encoding into the previous dispatch during SetEncoding itself is that it could be complicated, because:

  1. The producer dispatch region could have other uses, in which case we might not want to move the set_encoding into it.
  2. We'll need to update the return type of the producer dispatch region op, which makes the SetEncoding pass more complicated.
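
A minimal sketch of the wrap-in-place option, with the encoding attribute abbreviated as #enc (the exact pass mechanics are still TBD):

// Before: a set_encoding left outside any dispatch region.
%2_with_encoding = iree_encoding.set_encoding %2 : tensor<?x?xi8> -> tensor<?x?xi8, #enc>

// After: the set_encoding gets its own dispatch region in place.
%2_with_encoding = flow.dispatch.region -> (tensor<?x?xi8, #enc>{%0, %1}) {
  %e = iree_encoding.set_encoding %2 : tensor<?x?xi8> -> tensor<?x?xi8, #enc>
  flow.return %e : tensor<?x?xi8, #enc>
}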

@MaheshRavishankar @Max191 does it make sense to you? We can chat more details tomorrow or some time this week.

MaheshRavishankar commented 1 week ago

Yes, this makes sense. A small nit: you are still moving the encodings into the producer dispatch; doing it as a separate pass makes sense.

Max191 commented 1 week ago

That looks similar to what I had in mind. The one difference is that I was thinking SetEncoding should only be responsible for setting the encodings on the target ops (i.e., not dequant ops). Then a separate pass will propagate the encoding ops and possibly also form dispatch regions. The reason I think this is better is that including all the logic for setting encodings on dequant ops could make SetEncoding complicated.

I'm actually still not entirely sure how to propagate the encodings past dequant-like ops, since they may not necessarily be unary like the one in your example IR. Dequant-like ops could be doing real dequantization, meaning they have scale and zero point inputs along with the quantized values. Setting an encoding on such an op means we have to set encodings on the scales and zero points as well. This leads to two issues:

  1. The scales and zero points have a single value for each quantization group, meaning that they have one fewer dimension than the quantized tensor. We may need some extra representation in encodings to capture which dimensions of the encoded tensor are present and how they correspond to the dimensions of the encoding user (i.e., the matmul). Broadcast ops have a similar problem, since the inputs to broadcasts are also missing a dimension. Although this is more long term, propagation through transposes would most likely need this as well.
  2. Dequantization ops with zero points cause issues with padding. For example, say there is a tensor<2x30xi4> quantized tensor with a group size of 30; the scales and zero points will be tensor<2xi32>. If we pack the 30 dimension of the quantized values with an inner tile of 16, we get tensor<2x2x16xi4> and tensor<2xi32>. The padded lanes of the quantized tensor will now use the non-zero zero points from their tiles, which could create non-zero values in the dequantized result along the padded border (sketched below).
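
To make the hazard concrete, here is a sketch of the pack from the example above (%quant is the tensor<2x30xi4> quantized tensor; tensor.pack stands in for what materialization produces, and the zero padding value is an assumption):

%pad = arith.constant 0 : i4
%dest = tensor.empty() : tensor<2x2x16xi4>
// Pack the group dimension (size 30) with an inner tile of 16; the last
// two lanes of each second tile are padding filled with 0.
%packed = tensor.pack %quant padding_value(%pad : i4)
    inner_dims_pos = [1] inner_tiles = [16]
    into %dest : tensor<2x30xi4> -> tensor<2x2x16xi4>
// Dequantization computes (q - zero_point) * scale per group. The padded
// lanes hold q = 0, so a non-zero zero_point dequantizes them to
// -zero_point * scale rather than 0, polluting the padded border.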

(1) is probably a prerequisite to these flow changes, since we will not be able to propagate through dequant ops without it. After some discussion with @hanhanW, the current idea is to add a projected permutation map to the encoding, which indicates which values in the tensor correspond to the dimensions of the originalType.

(2) is more of a codegen/materialization issue, so it is not blocking any progress on flow level changes, but it is something we need to be careful with. I will create a separate issue about this tomorrow.

MaheshRavishankar commented 1 week ago

> That looks similar to what I had in mind. The one difference is that I was thinking SetEncoding should only be responsible for setting the encodings on the target ops (i.e., not dequant ops). Then a separate pass will propagate the encoding ops and possibly also form dispatch regions. The reason I think this is better is that including all the logic for setting encodings on dequant ops could make SetEncoding complicated.

> I'm actually still not entirely sure how to propagate the encodings past dequant-like ops, since they may not necessarily be unary like the one in your example IR. Dequant-like ops could be doing real dequantization, meaning they have scale and zero point inputs along with the quantized values. Setting an encoding on such an op means we have to set encodings on the scales and zero points as well. This leads to two issues:

> 1. The scales and zero points have a single value for each quantization group, meaning that they have one fewer dimension than the quantized tensor. We may need some extra representation in encodings to capture which dimensions of the encoded tensor are present and how they correspond to the dimensions of the encoding user (i.e., the matmul). Broadcast ops have a similar problem, since the inputs to broadcasts are also missing a dimension. Although this is more long term, propagation through transposes would most likely need this as well.

Yeah, the encoding needs to represent how this data is accessed in the eventual operation... but at the point you are setting the encoding you can encode that logic into the indexing maps.

> 2. Dequantization ops with zero points cause issues with padding. For example, say there is a tensor<2x30xi4> quantized tensor with a group size of 30; the scales and zero points will be tensor<2xi32>. If we pack the 30 dimension of the quantized values with an inner tile of 16, we get tensor<2x2x16xi4> and tensor<2xi32>. The padded lanes of the quantized tensor will now use the non-zero zero points from their tiles, which could create non-zero values in the dequantized result along the padded border.

This is implicitly doing pad propagation during encoding. So again, we will have to be careful about what padding value to use....

> (1) is probably a prerequisite to these flow changes, since we will not be able to propagate through dequant ops without it. After some discussion with @hanhanW, the current idea is to add a projected permutation map to the encoding, which indicates which values in the tensor correspond to the dimensions of the originalType.

I hope what I said above is consistent with what you meant here.

> (2) is more of a codegen/materialization issue, so it is not blocking any progress on flow level changes, but it is something we need to be careful with. I will create a separate issue about this tomorrow.

Max191 commented 1 week ago

> Yeah, the encoding needs to represent how this data is accessed in the eventual operation... but at the point you are setting the encoding you can encode that logic into the indexing maps.

I don't think this is possible. The user_indexing_maps are used to determine the contraction dimensions of the operation. In cases where the encoding is set on the input of a broadcast, some of the dimensions would be missing, and the contraction dims would not be properly inferred. We need a separate indexing map to show how the encoded tensor's dims map to the originalType (i.e., the type used in the matmul).

MaheshRavishankar commented 1 week ago

> Yeah, the encoding needs to represent how this data is accessed in the eventual operation... but at the point you are setting the encoding you can encode that logic into the indexing maps.

> I don't think this is possible. The user_indexing_maps are used to determine the contraction dimensions of the operation. In cases where the encoding is set on the input of a broadcast, some of the dimensions would be missing, and the contraction dims would not be properly inferred. We need a separate indexing map to show how the encoded tensor's dims map to the originalType (i.e., the type used in the matmul).

I think that is just a fix to the contraction dims inference rather than a new field.

hanhanW commented 1 week ago

Status update: I have a change that sets encodings for matmuls; see the snippet below. The next steps are to bubble the set_encoding ops up across the generic ops, set the bcast_map if the generic op has broadcast semantics, and hoist the set_encoding ops out of the dispatch region.

#map = affine_map<(d0, d1, d2) -> (d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
module {
  util.func public @broadcasting_dequant_op(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @broadcasting_dequant_op(%input0: tensor<?x?xi8>, %input1: tensor<?x?x?xi32>) -> (%output0: tensor<?x?x?xi32>)"}} {
    %c0_i32 = arith.constant 0 : i32
    %0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
    %1 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[1] : index
    %2 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<?x?xi8>{%0, %1}
    %3 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[0] : index
    %4 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[1] : index
    %5 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[2] : index
    %6 = hal.tensor.import %arg1 "input1" : !hal.buffer_view -> tensor<?x?x?xi32>{%3, %4, %5}
    %7 = flow.dispatch.region -> (tensor<?x?x?xi32>{%3, %0, %4}) {
      %9 = tensor.empty(%3, %0, %1) : tensor<?x?x?xi32>
      %10 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%2 : tensor<?x?xi8>) outs(%9 : tensor<?x?x?xi32>) {
      ^bb0(%in: i8, %out: i32):
        %17 = arith.extui %in : i8 to i32
        linalg.yield %17 : i32
      } -> tensor<?x?x?xi32>
      %11 = iree_encoding.set_encoding %10 : tensor<?x?x?xi32> -> tensor<?x?x?xi32, #iree_encoding.encoding<operand_index = 0 : index, element_types = [i8, i32, i32], original_type = tensor<?x?x?xi32>, user_indexing_maps = [#map2, #map3, #map4], round_dims_to = array<i64: 16, 16, 16, 16>>>
      %12 = iree_encoding.set_encoding %6 : tensor<?x?x?xi32> -> tensor<?x?x?xi32, #iree_encoding.encoding<operand_index = 1 : index, element_types = [i8, i32, i32], original_type = tensor<?x?x?xi32>, user_indexing_maps = [#map2, #map3, #map4], round_dims_to = array<i64: 16, 16, 16, 16>>>
      %13 = tensor.empty(%3, %0, %4) : tensor<?x?x?xi32, #iree_encoding.encoding<operand_index = 2 : index, element_types = [i8, i32, i32], original_type = tensor<?x?x?xi32>, user_indexing_maps = [#map2, #map3, #map4], round_dims_to = array<i64: 16, 16, 16, 16>>>
      %14 = linalg.fill ins(%c0_i32 : i32) outs(%13 : tensor<?x?x?xi32, #iree_encoding.encoding<operand_index = 2 : index, element_types = [i8, i32, i32], original_type = tensor<?x?x?xi32>, user_indexing_maps = [#map2, #map3, #map4], round_dims_to = array<i64: 16, 16, 16, 16>>>) -> tensor<?x?x?xi32, #iree_encoding.encoding<operand_index = 2 : index, element_types = [i8, i32, i32], original_type = tensor<?x?x?xi32>, user_indexing_maps = [#map2, #map3, #map4], round_dims_to = array<i64: 16, 16, 16, 16>>>
      %15 = linalg.batch_matmul_transpose_b ins(%11, %12 : tensor<?x?x?xi32, #iree_encoding.encoding<operand_index = 0 : index, element_types = [i8, i32, i32], original_type = tensor<?x?x?xi32>, user_indexing_maps = [#map2, #map3, #map4], round_dims_to = array<i64: 16, 16, 16, 16>>>, tensor<?x?x?xi32, #iree_encoding.encoding<operand_index = 1 : index, element_types = [i8, i32, i32], original_type = tensor<?x?x?xi32>, user_indexing_maps = [#map2, #map3, #map4], round_dims_to = array<i64: 16, 16, 16, 16>>>) outs(%14 : tensor<?x?x?xi32, #iree_encoding.encoding<operand_index = 2 : index, element_types = [i8, i32, i32], original_type = tensor<?x?x?xi32>, user_indexing_maps = [#map2, #map3, #map4], round_dims_to = array<i64: 16, 16, 16, 16>>>) -> tensor<?x?x?xi32, #iree_encoding.encoding<operand_index = 2 : index, element_types = [i8, i32, i32], original_type = tensor<?x?x?xi32>, user_indexing_maps = [#map2, #map3, #map4], round_dims_to = array<i64: 16, 16, 16, 16>>>
      %16 = iree_encoding.unset_encoding %15 : tensor<?x?x?xi32, #iree_encoding.encoding<operand_index = 2 : index, element_types = [i8, i32, i32], original_type = tensor<?x?x?xi32>, user_indexing_maps = [#map2, #map3, #map4], round_dims_to = array<i64: 16, 16, 16, 16>>> -> tensor<?x?x?xi32>
      %extracted_slice = tensor.extract_slice %16[0, 0, 0] [%3, %0, %4] [1, 1, 1] : tensor<?x?x?xi32> to tensor<?x?x?xi32>
      flow.return %extracted_slice : tensor<?x?x?xi32>
    }
    %8 = hal.tensor.export %7 "output0" : tensor<?x?x?xi32>{%3, %0, %4} -> !hal.buffer_view
    util.return %8 : !hal.buffer_view
  }
}

Mahesh, I understand that you have concerns about bcast_map. I'll try to connect more pieces and see whether it is needed. My intuition tells me that it is, as I explained in https://github.com/iree-org/iree/pull/17763#issuecomment-2195806300. Let's see.

(inlining the comment below for convenience)

Say that we have broadcasting_dequant + matmul in a dispatch, and we need to allocate a buffer for the input of the dispatch. The original indexing_map from the batch_matmul is (b, m, n, k) -> (b, n, k), and we are broadcasting across the batch dimension (i.e., the indexing_map in the broadcast op is (b, n, k) -> (n, k)). At the stream level, we want to allocate an n * k buffer. If we don't encode the broadcast indexing map, we don't know the allocation size; the missing dimension could be any of b, n, and k. I think this is the cost we pay for broadcast fusion.

I don't think we can reuse the original user_indexing_maps field, because it makes the logic very tricky: we still need the original matmul indexing maps to infer the contraction dims, and replacing the corresponding indexing map with the broadcast map loses that information. Decoupling the logic into a new field reduces the complexity.
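
A minimal sketch of what such an encoding could look like for the example above (field spellings follow the snippet later in the thread; the map values are just the ones from this example):

#iree_encoding.encoding<
    operand_index = 1 : index,
    element_types = [i8, i32, i32],
    // The full matmul maps, kept intact for contraction-dim inference.
    user_indexing_maps = [affine_map<(b, m, n, k) -> (b, m, k)>,
                          affine_map<(b, m, n, k) -> (b, n, k)>,
                          affine_map<(b, m, n, k) -> (b, m, n)>],
    // Separate broadcast map: only (n, k) are actually stored, so the
    // stream-level allocation is n * k.
    bcast_map = affine_map<(b, n, k) -> (n, k)>>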

MaheshRavishankar commented 1 week ago

Cross-posting my response as well:

"One way to maybe fix it is to change the indexing maps to be an list(list(indexing_maps)). So basically a composition of the indexing maps allows you to get back to the original op. For example for the broadcast case the indexing map would be [[affine_map<(b, m, n, k) -> (b, n, k)>, affine_map<(b,n,k) -> (n, k)>], .., ...] Then this also generalizes (say you want to fuse the transpose in or some arbitrary chain in the future."

hanhanW commented 1 week ago

I pushed the changes to https://github.com/iree-org/iree/tree/shared/data-tiling-fusion

Run the e2e compilation flow:

iree-compile --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu=znver4 --iree-llvmcpu-target-triple=x86_64-unknown-linux-gnu --iree-opt-data-tiling=false --iree-flow-enable-data-tiling ~/matmul.mlir -o /tmp/z.vmfb

The current status is that all the ops land in the same dispatch. As mentioned in the other comments, the next steps are moving the set_encoding ops out and fusing them with their producers. Max and I will work on this. For reference, the input (~/matmul.mlir) is:

util.func public @broadcasting_dequant_op(%arg0 : tensor<?x?xi8>, %rhs : tensor<?x?x?xi32>) -> tensor<?x?x?xi32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d1 = tensor.dim %arg0, %c0 : tensor<?x?xi8>
  %d2 = tensor.dim %arg0, %c1 : tensor<?x?xi8>
  %d0 = tensor.dim %rhs, %c0 : tensor<?x?x?xi32>
  %empty = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xi32>
  %dequant = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%arg0 : tensor<?x?xi8>) outs(%empty : tensor<?x?x?xi32>) {
    ^bb0(%in: i8, %out: i32):
      %12 = arith.extui %in : i8 to i32
      linalg.yield %12 : i32
    } -> tensor<?x?x?xi32>

  %B = tensor.dim %dequant, %c0 : tensor<?x?x?xi32>
  %M = tensor.dim %dequant, %c1 : tensor<?x?x?xi32>
  %N = tensor.dim %rhs, %c1 : tensor<?x?x?xi32>
  %init = tensor.empty(%B, %M, %N) : tensor<?x?x?xi32>
  %cst = arith.constant 0 : i32
  %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<?x?x?xi32>) -> tensor<?x?x?xi32>
  %op = linalg.batch_matmul_transpose_b
      ins(%dequant, %rhs : tensor<?x?x?xi32>, tensor<?x?x?xi32>)
      outs(%fill : tensor<?x?x?xi32>) -> tensor<?x?x?xi32>
  util.return %op : tensor<?x?x?xi32>
}

hanhanW commented 4 days ago

I took a stab at writing down the input IR for the materialization pass. The first two set_encoding ops should be hoisted into other new dispatches, but this is okay for testing; the pass should be able to handle the case. I'm working on teaching the materialization patterns to take bcast_map into account; so far it just crashes in my prototype.

#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map4 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func.func @broadcasting_dequant_matmul(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {
  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="aarch64-xyz-xyz"}>
} {

  %M = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
  %B = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[0] : index
  %N = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[1] : index
  %K = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[2] : index

  %0 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<?x?xi8>{%M, %K}
  %1 = hal.tensor.import %arg1 "input1" : !hal.buffer_view -> tensor<?x?x?xi32>{%B, %N, %K}

  %encoding_0 = iree_encoding.set_encoding %0 : tensor<?x?xi8>
    -> tensor<?x?xi8, #iree_encoding.encoding<
        operand_index = 0 : index,
        element_types = [i8, i32, i32],
        user_indexing_maps = [#map, #map1, #map2],
        bcast_map = #map3,
        round_dims_to = array<i64: 16, 16, 16>>>

  %encoding_1 = iree_encoding.set_encoding %1 : tensor<?x?x?xi32>
    -> tensor<?x?x?xi32, #iree_encoding.encoding<
        operand_index = 1 : index,
        element_types = [i8, i32, i32],
        user_indexing_maps = [#map, #map1, #map2],
        round_dims_to = array<i64: 16, 16, 16>>>

  %3 = tensor.empty(%B, %M, %K) : tensor<?x?x?xi32, #iree_encoding.encoding<
        operand_index = 0 : index,
        element_types = [i8, i32, i32],
        user_indexing_maps = [#map, #map1, #map2],
        bcast_map = #map3,
        round_dims_to = array<i64: 16, 16, 16>>>

  %4 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "parallel", "parallel"]}
    ins(%encoding_0 : tensor<?x?xi8, #iree_encoding.encoding<
        operand_index = 0 : index,
        element_types = [i8, i32, i32],
        user_indexing_maps = [#map, #map1, #map2],
        bcast_map = #map3,
        round_dims_to = array<i64: 16, 16, 16>>>)
     outs(%3 : tensor<?x?x?xi32, #iree_encoding.encoding<
        operand_index = 0 : index,
        element_types = [i8, i32, i32],
        user_indexing_maps = [#map, #map1, #map2],
        bcast_map = #map3,
        round_dims_to = array<i64: 16, 16, 16>>>) {
  ^bb0(%in: i8, %out: i32):
    %5 = arith.extui %in : i8 to i32
    linalg.yield %5 : i32
  } -> tensor<?x?x?xi32, #iree_encoding.encoding<
        operand_index = 0 : index,
        element_types = [i8, i32, i32],
        user_indexing_maps = [#map, #map1, #map2],
        bcast_map = #map3,
        round_dims_to = array<i64: 16, 16, 16>>>

  %c0_i32 = arith.constant 0 : i32
  %5 = tensor.empty(%B, %M, %N) : tensor<?x?x?xi32, #iree_encoding.encoding<
        operand_index = 2 : index,
        element_types = [i8, i32, i32],
        user_indexing_maps = [#map, #map1, #map2],
        round_dims_to = array<i64: 16, 16, 16>>>
  %6 = linalg.fill
    ins(%c0_i32 : i32)
    outs(%5 : tensor<?x?x?xi32, #iree_encoding.encoding<
        operand_index = 2 : index,
        element_types = [i8, i32, i32],
        user_indexing_maps = [#map, #map1, #map2],
        round_dims_to = array<i64: 16, 16, 16>>>)
    -> tensor<?x?x?xi32, #iree_encoding.encoding<
    operand_index = 2 : index,
        element_types = [i8, i32, i32],
        user_indexing_maps = [#map, #map1, #map2],
        round_dims_to = array<i64: 16, 16, 16>>>

  %7 = linalg.batch_matmul_transpose_b
    ins(%4, %encoding_1 :
      tensor<?x?x?xi32, #iree_encoding.encoding<
        operand_index = 0 : index,
        element_types = [i8, i32, i32],
        user_indexing_maps = [#map, #map1, #map2],
        bcast_map = #map3,
        round_dims_to = array<i64: 16, 16, 16>>>,
      tensor<?x?x?xi32, #iree_encoding.encoding<
        operand_index = 1 : index,
        element_types = [i8, i32, i32],
        user_indexing_maps = [#map, #map1, #map2],
        round_dims_to = array<i64: 16, 16, 16>>>)
    outs(%6 :
      tensor<?x?x?xi32, #iree_encoding.encoding<
        operand_index = 2 : index,
            element_types = [i8, i32, i32],
            user_indexing_maps = [#map, #map1, #map2],
            round_dims_to = array<i64: 16, 16, 16>>>)
    -> tensor<?x?x?xi32, #iree_encoding.encoding<
        operand_index = 2 : index,
            element_types = [i8, i32, i32],
            user_indexing_maps = [#map, #map1, #map2],
            round_dims_to = array<i64: 16, 16, 16>>>
  %8 = iree_encoding.unset_encoding %7
    : tensor<?x?x?xi32, #iree_encoding.encoding<
        operand_index = 2 : index,
            element_types = [i8, i32, i32],
            user_indexing_maps = [#map, #map1, #map2],
            round_dims_to = array<i64: 16, 16, 16>>>
    -> tensor<?x?x?xi32>
  %extracted_slice = tensor.extract_slice %8[0, 0, 0] [%B, %M, %N] [1, 1, 1] : tensor<?x?x?xi32> to tensor<?x?x?xi32>
  %9 = hal.tensor.export %extracted_slice "output0" : tensor<?x?x?xi32>{%B, %M, %N} -> !hal.buffer_view

  func.return %9 : !hal.buffer_view
}
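
For reference, a rough sketch of what materializing the broadcasted LHS set_encoding above might produce once bcast_map is handled (tile sizes are illustrative only, and tensor.pack stands in for today's lowering target, not a committed design): the bcast_map says only the (m, k) dims of the (b, m, k) operand are actually stored, so we pack the 2D source %0 directly and never materialize the broadcast batch dimension.

%c16 = arith.constant 16 : index
%zero = arith.constant 0 : i8
// Outer dims after tiling M and K by 16 (ceil division accounts for padding).
%M_tiles = arith.ceildivui %M, %c16 : index
%K_tiles = arith.ceildivui %K, %c16 : index
%dest = tensor.empty(%M_tiles, %K_tiles) : tensor<?x?x16x16xi8>
// Pack only the stored (m, k) dims of the LHS; padding fills the border.
%packed = tensor.pack %0 padding_value(%zero : i8)
    inner_dims_pos = [0, 1] inner_tiles = [16, 16]
    into %dest : tensor<?x?xi8> -> tensor<?x?x16x16xi8>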