iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Supporting non aligned reduction splitting #13115

Open ThomasRaoux opened 1 year ago

ThomasRaoux commented 1 year ago

Request description

Currently, the reduction splitting pass only handles cases where the reduction dimension is divisible by the given ratio.

This is because the transformation happens before dispatch region creation and, for simplicity and to avoid handling loops at the graph level, it is done as a linalg-to-linalg rewrite. Therefore the transformation converts:

    %0 = linalg.matmul ins(%arg0, %arg1 : tensor<16x256xf32>, tensor<256x32xf32>)
                       outs(%arg2 : tensor<16x32xf32>) -> tensor<16x32xf32>

into:

    %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] : tensor<16x256xf32> into tensor<16x4x64xf32>
    %expanded_0 = tensor.expand_shape %arg1 [[0, 1], [2]] : tensor<256x32xf32> into tensor<4x64x32xf32>
    %0 = tensor.empty() : tensor<4x16x32xf32>
    %cst = arith.constant 0.000000e+00 : f32
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4x16x32xf32>) -> tensor<4x16x32xf32>
    %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%expanded, %expanded_0 : tensor<16x4x64xf32>, tensor<4x64x32xf32>) outs(%1 : tensor<4x16x32xf32>) {
    ^bb0(%in: f32, %in_1: f32, %out: f32):
      %4 = arith.mulf %in, %in_1 : f32
      %5 = arith.addf %out, %4 : f32
      linalg.yield %5 : f32
    } -> tensor<4x16x32xf32>
    %3 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["reduction", "parallel", "parallel"]} ins(%2 : tensor<4x16x32xf32>) outs(%arg2 : tensor<16x32xf32>) {
    ^bb0(%in: f32, %out: f32):
      %4 = arith.addf %in, %out : f32
      linalg.yield %4 : f32
    } -> tensor<16x32xf32>

The downside is that the expanded tensor needs to be hyper-rectangular: %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] : tensor<16x256xf32> into tensor<16x4x64xf32>. This is not possible when the reduction dimension is not divisible by the split ratio.
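A minimal pure-Python reference model of the aligned rewrite above may help make the restriction concrete. It is a sketch of the semantics only (the function names are made up, and it says nothing about how IREE implements the pass): K is cut into ratio equal chunks, each chunk yields one partial product, and the merge sums the partials. The divisibility check plays the role of the hyper-rectangular tensor.expand_shape.

```python
def matmul(a, b):
    """Reference matmul on lists of lists: a is M x K, b is K x N."""
    m, k, n = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
            for i in range(m)]


def split_k_aligned(a, b, ratio):
    """Model of the linalg-to-linalg split-K rewrite (aligned case only)."""
    m, k, n = len(a), len(b), len(b[0])
    if k % ratio != 0:
        # An expand_shape of K into [ratio, K / ratio] would not be rectangular.
        raise ValueError(f"K={k} is not divisible by ratio={ratio}")
    chunk = k // ratio
    # partials[r] plays the role of the ratio x M x N intermediate (%2).
    partials = [[[sum(a[i][r * chunk + q] * b[r * chunk + q][j]
                      for q in range(chunk))   # inner "reduction" dim
                  for j in range(n)]
                 for i in range(m)]
                for r in range(ratio)]         # new "parallel" split dim
    # Merge (%3): reduce over the split dimension.
    return [[sum(partials[r][i][j] for r in range(ratio)) for j in range(n)]
            for i in range(m)]


a = [[i + p for p in range(8)] for i in range(3)]  # 3 x 8
b = [[p - j for j in range(2)] for p in range(8)]  # 8 x 2
assert split_k_aligned(a, b, 4) == matmul(a, b)
```

Calling split_k_aligned(a, b, 3) raises, since 8 is not a multiple of 3; that unaligned case is exactly what this issue is about.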

Here are the potential solutions to relax this restriction:

1) The main challenge is that, in the unaligned case, the resulting matmul is not a batch matmul but a grouped GEMM (a batch of matmuls with different sizes). Being able to represent that with a structured op would allow us to do this transformation at the graph level, and it would work once codegen supports grouped GEMM. Since non-hyper-rectangular shapes cannot be represented, we would need to introduce a new ad hoc op that expresses the computation without expanding the tensors.

The IR would look like:

    %0 = tensor.empty() : tensor<4x16x32xf32>
    %cst = arith.constant 0.000000e+00 : f32
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4x16x32xf32>) -> tensor<4x16x32xf32>
    %2 = linalg.ext_split_grouped_gemm(%arg0 : tensor<16x256xf32>, %arg1 : tensor<256x32xf32>, %1: tensor<4x16x32xf32>) -> tensor<4x16x32xf32>
    %3 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["reduction", "parallel", "parallel"]} ins(%2 : tensor<4x16x32xf32>) outs(%arg2 : tensor<16x32xf32>) {
    ^bb0(%in: f32, %out: f32):
      %4 = arith.addf %in, %out : f32
      linalg.yield %4 : f32
    } -> tensor<16x32xf32>

2) Introduce scf.forall loops at the graph level. The same transformation also exists in a form analogous to tiling to SCF, which rewrites linalg into linalg + loops.

In this case the IR looks like:

    %cst = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<16x32x4xf32>
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<16x32x4xf32>) -> tensor<16x32x4xf32>
    %2 = scf.forall (%arg3) in (4) shared_outs(%arg4 = %1) -> (tensor<16x32x4xf32>) {
      %extracted_slice = tensor.extract_slice %arg4[0, 0, %arg3] [16, 32, 1] [1, 1, 1] : tensor<16x32x4xf32> to tensor<16x32xf32>
      %4 = affine.apply #map(%arg3)
      %extracted_slice_0 = tensor.extract_slice %arg0[0, %4] [16, 64] [1, 1] : tensor<16x256xf32> to tensor<16x64xf32>
      %extracted_slice_1 = tensor.extract_slice %arg1[%4, 0] [64, 32] [1, 1] : tensor<256x32xf32> to tensor<64x32xf32>
      %5 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_1 : tensor<16x64xf32>, tensor<64x32xf32>) outs(%extracted_slice : tensor<16x32xf32>) -> tensor<16x32xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %5 into %arg4[0, 0, %arg3] [16, 32, 1] [1, 1, 1] : tensor<16x32xf32> into tensor<16x32x4xf32>
      }
    }
    %3 = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%2 : tensor<16x32x4xf32>) outs(%arg2 : tensor<16x32xf32>) {
    ^bb0(%in: f32, %out: f32):
      %4 = arith.addf %in, %out : f32
      linalg.yield %4 : f32
    } -> tensor<16x32xf32>

Then the scf.forall has to become the root of a dispatch region (fused with the linalg.fill). This would require allowing scf.forall at the graph level. The changes should be minimal, since it can be treated like any other structured op, but it opens up a significantly different strategy from what IREE currently supports, and we need to be sure this fits IREE's long-term direction. It will also require the backend to support such loops and to know how to distribute them.
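Both options compute the same thing: one partial product per K-group, followed by a reduction over the split dimension. The pure-Python reference model below illustrates those semantics for the unaligned case. Neither proposal specifies how K is partitioned, so the balanced partition used here (group sizes differ by at most one) is an assumption, as are all function names.

```python
def k_groups(k, ratio):
    """Hypothetical partition of K into `ratio` contiguous groups whose
    sizes differ by at most one (the first k % ratio groups get one extra)."""
    base, extra = divmod(k, ratio)
    groups, offset = [], 0
    for g in range(ratio):
        size = base + (1 if g < extra else 0)
        groups.append((offset, size))
        offset += size
    return groups


def split_k_partials(a, b, ratio):
    """Grouped-GEMM semantics: group g is an independent M x N matmul over
    its own K-slice (sizes may differ). Results go to partials[i][j][g],
    i.e. the split dimension is innermost, as in tensor<16x32x4xf32>."""
    m, n = len(a), len(b[0])
    partials = [[[0] * ratio for _ in range(n)] for _ in range(m)]
    for g, (offset, size) in enumerate(k_groups(len(b), ratio)):
        for i in range(m):
            for j in range(n):
                partials[i][j][g] = sum(a[i][offset + q] * b[offset + q][j]
                                        for q in range(size))
    return partials


def merge(partials):
    """The trailing linalg.generic: reduce over the split dimension."""
    return [[sum(cell) for cell in row] for row in partials]


print(k_groups(250, 4))  # [(0, 63), (63, 63), (126, 62), (188, 62)]
```

Each group's loop body is independent of the others, which is what lets either an ad hoc grouped-GEMM op or an scf.forall iteration carry it.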

@nicolasvasilache @MaheshRavishankar @stellaraccident @mattwalsh @benvanik

We need to decide on the direction for lifting the restrictions on split-K. The most concrete solution at this point is to introduce loops before dispatch region creation; it would be good to decide whether this is something we want to support in IREE.

Otherwise it would be good to figure out if there are ways to relax our tensor types/linalg ops to be able to represent that without loops. @nicolasvasilache what do you think?

MaheshRavishankar commented 1 year ago

Solution 2 is what I had in mind as well. With https://github.com/openxla/iree/pull/13038 it would also resolve issues w.r.t. workgroup count region resolution. So this works well AFAICS.

ThomasRaoux commented 1 year ago

> Solution 2 is what I had in mind as well. With #13038 it would also resolve issues w.r.t workgroup count region resolution as well. So this works well AFAICS.

Good to know. The more I think about it, the more I think it might work fine too. There will be a bit of work in codegen to support distributing two levels of scf.forall, but it might be fine. This is probably something we can prototype without too much trouble.

MaheshRavishankar commented 1 year ago

> > Solution 2 is what I had in mind as well. With #13038 it would also resolve issues w.r.t workgroup count region resolution as well. So this works well AFAICS.

> good to know. The more I think about the more I think it might work fine also. There will be a bit or work in the codegen to support distributing two levels of scf.forall but it might be fine. This is probably something we can prototype without too much trouble.

Yes, that will be the complication. I don't know of easy ways of doing that, because it is the same level of distribution. W.r.t. workgroup count computation, I am adding an op to the transform dialect to connect everything. That might need to evolve.

nicolasvasilache commented 1 year ago

Good that people's thoughts are evolving in this direction: scf.forall at the graph level is the one key abstraction identified a while ago that has been hard to get buy-in on. If the magic words are simply "split K is better represented this way", this works for me!

On Sat, Apr 15, 2023 at 9:17 PM MaheshRavishankar wrote:

> > > Solution 2 is what I had in mind as well. With #13038 it would also resolve issues w.r.t workgroup count region resolution as well. So this works well AFAICS.

> > good to know. The more I think about the more I think it might work fine also. There will be a bit or work in the codegen to support distributing two levels of scf.forall but it might be fine. This is probably something we can prototype without too much trouble.

> Yes that will be the complication. I don't know easy ways of doing that cause it is the same level of distribution. W.r.t workgroup count computation I am adding an op to transform dialect to connect everything. That might need to evolve.

MaheshRavishankar commented 1 year ago

> Good that people's thoughts are evolving in this direction: scf.forall at the graph level is the one key abstraction identified a while ago that has been hard to get buy in on. If the magic words are simply "split K is better represented this way", this works for me !

Just to clarify: IMO this is not meant to say we should use scf.forall as a replacement for flow.dispatch.region or flow.dispatch.workgroups. Those still remain. We can use the scf.forall to partition the work for split-K, but the op itself should be moved into a flow.dispatch.region (and then into a flow.dispatch.workgroups). I hope we aren't talking about multi-dimensional tile+distribute using parametric values; I would still push back against that path. Here too, I actually don't care much about which op represents split-K. Some container to represent split-K is all I am going for. It could be scf.forall, linalg.generic as it is today, or linalg_ext.splitk_grouped_gemm; they are all equivalent AFAIC.

MaheshRavishankar commented 1 year ago

Thinking a little bit more about it: if fusion is important and we are using scf.forall, then rather than having separate heuristics, using what is here https://github.com/openxla/iree/blob/7171c452580a3bf6606491a08ae57083ad5bdf64/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp#L572 is probably better. This actually makes it an argument against using scf.forall and for using a separate op for split-K. Having a different op also means we don't need to worry about multiple scf.forall ops for different dimensions: we can have an op that implements the TilingInterface and tile and distribute everything at once. I am still partial to the scf.forall operation, since it would make the "split-K" path available to other operations if it can be plumbed through in a reasonable way.

MaheshRavishankar commented 1 year ago

Adding a few more questions here:

1) There is still an open question about how to determine when to apply split-K. I think that needs to happen in a preprocessing step, because it depends on the target and on the shape of the problem, and decisions of that nature don't fit into Flow. So while this improves the mechanism of split-K (I'm really glad this is being done; this is what I have always wanted split-K to become), it still needs to move out of Flow and into preprocessing. Any pointers to what is blocking this move?

2) There is still a complication in the backend.

Let's say we get this in the backend.

    %cst = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<16x32x4xf32>
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<16x32x4xf32>) -> tensor<16x32x4xf32>
    %2 = scf.forall (%arg3) in (4) shared_outs(%arg4 = %1) -> (tensor<16x32x4xf32>) {
      %extracted_slice = tensor.extract_slice %arg4[0, 0, %arg3] [16, 32, 1] [1, 1, 1] : tensor<16x32x4xf32> to tensor<16x32xf32>
      %4 = affine.apply #map(%arg3)
      %extracted_slice_0 = tensor.extract_slice %arg0[0, %4] [16, 64] [1, 1] : tensor<16x256xf32> to tensor<16x64xf32>
      %extracted_slice_1 = tensor.extract_slice %arg1[%4, 0] [64, 32] [1, 1] : tensor<256x32xf32> to tensor<64x32xf32>
      %5 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_1 : tensor<16x64xf32>, tensor<64x32xf32>) outs(%extracted_slice : tensor<16x32xf32>) -> tensor<16x32xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %5 into %arg4[0, 0, %arg3] [16, 32, 1] [1, 1, 1] : tensor<16x32xf32> into tensor<16x32x4xf32>
      }
    }

Now we will have to tile and distribute the inner linalg.matmul along m and n.

So this will become

    %cst = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<16x32x4xf32>
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<16x32x4xf32>) -> tensor<16x32x4xf32>
    %2 = scf.forall (%arg3) in (4) shared_outs(%arg4 = %1) -> (tensor<16x32x4xf32>) {
      %extracted_slice = tensor.extract_slice %arg4[0, 0, %arg3] [16, 32, 1] [1, 1, 1] : tensor<16x32x4xf32> to tensor<16x32xf32>
      %4 = affine.apply #map(%arg3)
      %extracted_slice_0 = tensor.extract_slice %arg0[0, %4] [16, 64] [1, 1] : tensor<16x256xf32> to tensor<16x64xf32>
      %extracted_slice_1 = tensor.extract_slice %arg1[%4, 0] [64, 32] [1, 1] : tensor<256x32xf32> to tensor<64x32xf32>
      %5 = scf.forall .... {
          ...
          ... = linalg.matmul
      }
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %5 into %arg4[0, 0, %arg3] [16, 32, 1] [1, 1, 1] : tensor<16x32xf32> into tensor<16x32x4xf32>
      }
    }

So this is a nested scf.forall. How do we know that the inner scf.forall targets the same level of the "parallelism hierarchy" (i.e. thread blocks) as the outer scf.forall, as opposed to a different level of the hierarchy such as warps or threads on GPU?
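One way to frame this question: if both foralls are meant to target thread blocks, their iteration spaces can be folded into a single grid in which every block id names exactly one (split index, M tile, N tile) triple. The Python sketch below models that folding; the tile sizes, the function names, and the choice of z for the split dimension are assumptions for illustration, not how IREE's codegen distributes these loops.

```python
from itertools import product


def ceil_div(a, b):
    return -(-a // b)


def flattened_grid(m, n, ratio, tile_m, tile_n):
    """Hypothetical: a single 3-D workgroup grid covering both loop levels,
    with x = N tiles, y = M tiles and z = split-K index."""
    return (ceil_div(n, tile_n), ceil_div(m, tile_m), ratio)


def work_item(block, tile_m, tile_n):
    """Map block (x, y, z) to (split index, row offset, column offset):
    the outer forall consumes z, the inner forall consumes y and x."""
    x, y, z = block
    return (z, y * tile_m, x * tile_n)


# The 16x32 output with a split of 4, tiled 8x16 (assumed tile sizes).
grid = flattened_grid(16, 32, 4, tile_m=8, tile_n=16)
blocks = product(*(range(extent) for extent in grid))
items = [work_item(blk, 8, 16) for blk in blocks]
print(grid, len(items), len(set(items)))  # (2, 2, 4) 16 16
```

If the inner forall instead targeted warps or threads, it would not contribute to this grid at all, which is why the distribution level needs to be explicit in the IR.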

allieculp commented 1 year ago

@MaheshRavishankar I assume this issue is waiting for Thomas to return to the office next week. Let us know if there are any other updates in the meantime.

MaheshRavishankar commented 1 year ago

Yup. We need to get an idea of what the end state is here. I think more details need to be worked out to get a full picture.

allieculp commented 1 year ago

Plan to discuss in GPU sync on Tuesday 5/9

ThomasRaoux commented 1 year ago

> Thinking a little bit more about it. If fusion is important, and we are using scf.for_all instead of having separate heuristics, using what is here https://github.com/openxla/iree/blob/7171c452580a3bf6606491a08ae57083ad5bdf64/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp#L572 is probably better. This actually makes it an argument against using scf.for_all and an argument for using a separate op for split-k. Having a different op will also mean we dont need to worry about multiple scf.for_all for different dimensions. We can have an op that implements the TilingInterface and tile and distribute everything at once. I am still partial towards the scf.for_all operation since that will make the "split-k" path available for other operations if it can be plumbed through in a reasonable way.

Note that to match what currently happens, we only need to fuse the fill with the scf.forall (and this is important). Since the merge op would not contain any scf.forall, fusing the merge with its producer should work out of the box. However, depending on the path we take to fuse ops like cast, broadcast, etc. into the consumer matmul, this may indeed be a problem. This is definitely something we need to consider, but I'm not sure we have enough visibility into what will happen there. @MaheshRavishankar this will need your input.

> 1. There is still an open question about how to determine when to apply split-k.... I think that needs to happen in preprocessing step (cause it is dependent on what the target is and shape of the problem, decisions of such nature dont fit into Flow). So while this improves the mechanism of split-k (really glad that this is being done, this was what I have always wanted to change split-k to) it still needs to move out of Flow and into preprocessing. Any pointers to what the blockers from doing this move are?

Correct, it is clear this needs to happen in preprocessing, so we should make sure whatever solution we pick works there. Both solutions mentioned above would work there, so I think it makes sense to treat this as an orthogonal problem and solve it separately.

> Any pointers to what the blockers from doing this move are?

The main blocker is that we don't have a place to start adding controls that depend on the target. The current preprocessing is mostly command-line based, and for this to work well we need to be able to add some basic target-based heuristics. Maybe the new nvgpu repo is the place, although it is still not very clear to me where exactly this code would go.

> So this is a nested scf.forall. How do we know that the inner scf.forall is targeting the same level of "parallelism heirarchy" (i.e thread blocks) as the outer scf.forall (as opposed to targeting a different level of the parallelism heirarchy like warps/threads on GPU).

The scf.forall ops are marked with the IDs they distribute over (for example, a mapping attribute such as [#gpu.block<z>]), so this would be explicit in the IR. However, I agree that dealing with the potential differences due to the nested scf.forall is going to be one of the main difficulties with this approach: we need to make sure we can reuse most of the matmul codegen, otherwise this solution is going to be very hard to deal with.

nicolasvasilache commented 1 year ago

> The main blocker is that we don't have a place to start adding control that depend on the target, the current pre-processing is mostly command line based and for this to work well we need to be able to add some basic heuristic based on target. Maybe the new nvgpu repo is the place although it is still not very clear to me where this code would go exactly.

+1, it would be great to have a place to add pluggable preprocessing strategies.

MaheshRavishankar commented 1 year ago

> > Thinking a little bit more about it. If fusion is important, and we are using scf.for_all instead of having separate heuristics, using what is here https://github.com/openxla/iree/blob/7171c452580a3bf6606491a08ae57083ad5bdf64/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp#L572 is probably better. This actually makes it an argument against using scf.for_all and an argument for using a separate op for split-k. Having a different op will also mean we dont need to worry about multiple scf.for_all for different dimensions. We can have an op that implements the TilingInterface and tile and distribute everything at once. I am still partial towards the scf.for_all operation since that will make the "split-k" path available for other operations if it can be plumbed through in a reasonable way.

> Note that to match what currently happen we only need to fuse fill with the scf.forall (and this is important), since the merge op would not have any scf.forall the fusion of the merge with producer should work out of the box. However depending on the path we take to fuse ops like cast, broadcast, etc... in the consumer matmul this may be a problem indeed. This is definitely something we need to consider but I'm not sure we have enough visibility on what will happen there. @MaheshRavishankar this will need your input.

One option I have been toying with is to instead implement split-K just before the FormDispatchWorkgroups pass, i.e. after the flow.dispatch.regions have already been formed: you can split a flow.dispatch.region into two. That still has the issue of how you control it. I have some ideas, but I need to sync with Ben to make sure this works e2e.

> > 1. There is still an open question about how to determine when to apply split-k.... I think that needs to happen in preprocessing step (cause it is dependent on what the target is and shape of the problem, decisions of such nature dont fit into Flow). So while this improves the mechanism of split-k (really glad that this is being done, this was what I have always wanted to change split-k to) it still needs to move out of Flow and into preprocessing. Any pointers to what the blockers from doing this move are?

> Correct, It is clear this needs to happen in pre-processing so we should make sure whatever solution we pick works there but both solution mentioned above would work there so I think it makes sense to treat it as an orthogonal problem and solve it separately.

> > Any pointers to what the blockers from doing this move are?

> The main blocker is that we don't have a place to start adding control that depend on the target, the current pre-processing is mostly command line based and for this to work well we need to be able to add some basic heuristic based on target. Maybe the new nvgpu repo is the place although it is still not very clear to me where this code would go exactly.

I don't think the command-line based options are that limiting: they are basically any pass option you can add to -pass-pipeline, so there should be almost no restrictions.

ThomasRaoux commented 1 year ago

> > > 1. There is still an open question about how to determine when to apply split-k.... I think that needs to happen in preprocessing step (cause it is dependent on what the target is and shape of the problem, decisions of such nature dont fit into Flow). So while this improves the mechanism of split-k (really glad that this is being done, this was what I have always wanted to change split-k to) it still needs to move out of Flow and into preprocessing. Any pointers to what the blockers from doing this move are?

> > Correct, It is clear this needs to happen in pre-processing so we should make sure whatever solution we pick works there but both solution mentioned above would work there so I think it makes sense to treat it as an orthogonal problem and solve it separately.

> > > Any pointers to what the blockers from doing this move are?

> > The main blocker is that we don't have a place to start adding control that depend on the target, the current pre-processing is mostly command line based and for this to work well we need to be able to add some basic heuristic based on target. Maybe the new nvgpu repo is the place although it is still not very clear to me where this code would go exactly.

> I dont think the command line based options are that limiting... they are basically any pass option you can add to -pass-pipeline... so should have almost no restrictions.

Could you give more details on what you have in mind here? Options are one thing, but we need some logic somewhere to make per-op decisions. We could potentially write some logic based on platform info (like the "number of threads" available on the target) and pass this as an option; is that what you had in mind? This would still most likely end up being biased toward certain targets, though.
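To illustrate the kind of per-op decision being discussed, here is a hypothetical Python sketch in which a single piece of target information (a count of compute units, which could be passed in as an ordinary pass option) drives the split factor. The thresholds, the tile sizes, the value 108 in the example, and the function names are all assumptions, not anything IREE implements; as noted above, any such rule will be biased toward some targets.

```python
def ceil_div(a, b):
    return -(-a // b)


def choose_split_k(m, n, k, tile_m, tile_n, num_compute_units,
                   min_k_per_split=256, max_split=64):
    """Hypothetical per-op heuristic driven by one piece of target info.

    Split K only if M/N tiling alone yields fewer workgroups than the target
    has compute units, and keep every K-slice at least min_k_per_split long.
    """
    tiles = ceil_div(m, tile_m) * ceil_div(n, tile_n)
    if tiles >= num_compute_units:
        return 1  # M/N tiling already fills the machine: no split.
    split = ceil_div(num_compute_units, tiles)
    return max(1, min(split, max_split, k // min_k_per_split))


print(choose_split_k(123, 456, 51234, 32, 128, num_compute_units=108))  # 7
```

Here 4 x 4 = 16 M/N tiles cannot fill 108 units, so K is split 7 ways; a large square matmul returns 1 and a short K is capped by min_k_per_split.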

nicolasvasilache commented 1 year ago

Coming back to this after vacation, I see quite a few alternatives being discussed, some more or less speculative. I'll do my best to answer in a structured way by just showing what we have available, starting from this commit: https://github.com/iree-org/iree-samples/commit/e030b0a49ce434dd286390f7fa7b66ebbccf0fad.

You can dig into more details by looking here https://github.com/iree-org/iree-samples/commit/e030b0a49ce434dd286390f7fa7b66ebbccf0fad#diff-4d6be20e8b5295b925af9509d5de4e32aba9e0837317eece2fd7c8d631c3e795R2 and repro by copy-pasting the commands provided (and occasionally commenting out parts of the script to see intermediate state).

Step 0: Input IR.

  func.func @matmul_static(%arg0: tensor<123x51234xf32>, %arg1: tensor<51234x456xf32>, %arg2: tensor<123x456xf32>) -> tensor<123x456xf32> {
    %0 = linalg.matmul ins(%arg0, %arg1 : tensor<123x51234xf32>, tensor<51234x456xf32>) outs(%arg2 : tensor<123x456xf32>) -> tensor<123x456xf32>
    return %0 : tensor<123x456xf32>
  }

Step 1: Split-K and fuse fill with scf.forall at the graph level.

  func.func @matmul_static(%arg0: tensor<123x51234xf32>, %arg1: tensor<51234x456xf32>, %arg2: tensor<123x456xf32>) -> tensor<123x456xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<123x456x77xf32>
    %1 = scf.forall (%arg3) in (77) shared_outs(%arg4 = %0) -> (tensor<123x456x77xf32>) {
      %3 = affine.min #map(%arg3)
      %4 = affine.max #map1(%3)
      %extracted_slice = tensor.extract_slice %arg4[0, 0, %arg3] [123, 456, 1] [1, 1, 1] : tensor<123x456x77xf32> to tensor<123x456x1xf32>
      %5 = linalg.fill ins(%cst : f32) outs(%extracted_slice : tensor<123x456x1xf32>) -> tensor<123x456x1xf32>
      %extracted_slice_0 = tensor.extract_slice %5[0, 0, 0] [123, 456, 1] [1, 1, 1] : tensor<123x456x1xf32> to tensor<123x456xf32>
      %6 = affine.apply #map2(%arg3)
      %extracted_slice_1 = tensor.extract_slice %arg0[0, %6] [123, %4] [1, 1] : tensor<123x51234xf32> to tensor<123x?xf32>
      %extracted_slice_2 = tensor.extract_slice %arg1[%6, 0] [%4, 456] [1, 1] : tensor<51234x456xf32> to tensor<?x456xf32>
      %7 = linalg.matmul ins(%extracted_slice_1, %extracted_slice_2 : tensor<123x?xf32>, tensor<?x456xf32>) outs(%extracted_slice_0 : tensor<123x456xf32>) -> tensor<123x456xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %7 into %arg4[0, 0, %arg3] [123, 456, 1] [1, 1, 1] : tensor<123x456xf32> into tensor<123x456x77xf32>
      }
    } {mapping = [#gpu.block<z>]}
    %2 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1 : tensor<123x456x77xf32>) outs(%arg2 : tensor<123x456xf32>) {
    ^bb0(%in: f32, %out: f32):
      %3 = arith.addf %in, %out : f32
      linalg.yield %3 : f32
    } -> tensor<123x456xf32>
    return %2 : tensor<123x456xf32>
  }
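The affine maps used in Step 1 (#map, #map1, #map2) are not printed, so the following Python sketch rests on an assumption: it models the usual ceil-div partition used when a reduction is tiled over a fixed number of parallel iterations, i.e. offset = iv * ceil(K / 77) and size = max(0, min(tile, K - offset)). Under that assumption it is consistent with the IR: 51234 is not a multiple of 77, so the slices of %arg0 and %arg1 get a dynamic (?) size.

```python
def split_k_slices(k, num_splits):
    """Assumed meaning of the maps in the scf.forall above:
    offset = iv * tile                      (#map2)
    size   = max(0, min(tile, k - offset))  (#map, then #map1)
    with tile = ceil(k / num_splits)."""
    tile = -(-k // num_splits)
    slices = []
    for iv in range(num_splits):
        offset = iv * tile
        size = max(0, min(tile, k - offset))
        slices.append((offset, size))
    return tile, slices


tile, slices = split_k_slices(51234, 77)
print(tile, slices[0], slices[-1])  # 666 (0, 666) (50616, 618)
```

The max(0, ...) only matters when the ceil-div overshoots: with K = 9 and 6 splits, the last slice starts at 10 and would otherwise get a negative size.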

Step 2: Convert the scf.forall to form a custom dispatch.workgroups.

  func.func @matmul_static(%arg0: tensor<123x51234xf32>, %arg1: tensor<51234x456xf32>, %arg2: tensor<123x456xf32>) -> tensor<123x456xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<123x456x77xf32>
    %1 = flow.dispatch.workgroups(%cst, %arg0, %arg1, %0) : (f32, tensor<123x51234xf32>, tensor<51234x456xf32>, tensor<123x456x77xf32>) -> tensor<123x456x77xf32> =
        (%arg3: f32, %arg4: !flow.dispatch.tensor<readonly:tensor<123x51234xf32>>, %arg5: !flow.dispatch.tensor<readonly:tensor<51234x456xf32>>, %arg6: !flow.dispatch.tensor<readonly:tensor<123x456x77xf32>>, %arg7: !flow.dispatch.tensor<writeonly:tensor<123x456x77xf32>>) {
      %3 = flow.dispatch.tensor.load %arg4, offsets = [0, 0], sizes = [123, 51234], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<123x51234xf32>> -> tensor<123x51234xf32>
      %4 = flow.dispatch.tensor.load %arg5, offsets = [0, 0], sizes = [51234, 456], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<51234x456xf32>> -> tensor<51234x456xf32>
      %5 = flow.dispatch.tensor.load %arg6, offsets = [0, 0, 0], sizes = [123, 456, 77], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<123x456x77xf32>> -> tensor<123x456x77xf32>
      %6 = scf.forall (%arg8) in (77) shared_outs(%arg9 = %5) -> (tensor<123x456x77xf32>) {
        %7 = affine.min #map(%arg8)
        %8 = affine.max #map1(%7)
        %extracted_slice = tensor.extract_slice %arg9[0, 0, %arg8] [123, 456, 1] [1, 1, 1] : tensor<123x456x77xf32> to tensor<123x456x1xf32>
        %9 = linalg.fill ins(%arg3 : f32) outs(%extracted_slice : tensor<123x456x1xf32>) -> tensor<123x456x1xf32>
        %extracted_slice_0 = tensor.extract_slice %9[0, 0, 0] [123, 456, 1] [1, 1, 1] : tensor<123x456x1xf32> to tensor<123x456xf32>
        %10 = affine.apply #map2(%arg8)
        %extracted_slice_1 = tensor.extract_slice %3[0, %10] [123, %8] [1, 1] : tensor<123x51234xf32> to tensor<123x?xf32>
        %extracted_slice_2 = tensor.extract_slice %4[%10, 0] [%8, 456] [1, 1] : tensor<51234x456xf32> to tensor<?x456xf32>
        %11 = linalg.matmul ins(%extracted_slice_1, %extracted_slice_2 : tensor<123x?xf32>, tensor<?x456xf32>) outs(%extracted_slice_0 : tensor<123x456xf32>) -> tensor<123x456xf32>
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %11 into %arg9[0, 0, %arg8] [123, 456, 1] [1, 1, 1] : tensor<123x456xf32> into tensor<123x456x77xf32>
        }
      } {mapping = [#gpu.block<z>]}
      flow.dispatch.tensor.store %6, %arg7, offsets = [0, 0, 0], sizes = [123, 456, 77], strides = [1, 1, 1] : tensor<123x456x77xf32> -> !flow.dispatch.tensor<writeonly:tensor<123x456x77xf32>>
      flow.return
    }
    %2 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1 : tensor<123x456x77xf32>) outs(%arg2 : tensor<123x456xf32>) {
    ^bb0(%in: f32, %out: f32):
      %3 = arith.addf %in, %out : f32
      linalg.yield %3 : f32
    } -> tensor<123x456xf32>
    return %2 : tensor<123x456xf32>
  }

Step 3: Use IREE for the other dispatch (custom dispatch formation and IREE's dispatch formation just compose)

module attributes {hal.device.targets = [#device_target_cuda]} {
  hal.executable private @matmul_static_dispatch_0 {
    hal.executable.variant public @cuda_nvptx_fb, target = #executable_target_cuda_nvptx_fb {
      hal.executable.export public @matmul_static_dispatch_0_matmul_123x456xD_f32 ordinal(0) layout(#pipeline_layout) attributes {translation_info = #translation} {
      ^bb0(%arg0: !hal.device):
        %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
        hal.return %x, %y, %z : index, index, index
      }
      builtin.module {
        func.func @matmul_static_dispatch_0_matmul_123x456xD_f32() {
          %c0 = arith.constant 0 : index
          %c17275136 = arith.constant 17275136 : index
          %cst = arith.constant 0.000000e+00 : f32
          %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<123x51234xf32>>
          ...
          %5 = scf.forall (%arg0) in (77) shared_outs(%arg1 = %4) -> (tensor<123x456x77xf32>) {
            ...
            %12 = linalg.matmul ins(%10, %11 : tensor<123x?xf32>, tensor<?x456xf32>) outs(%extracted_slice_0 : tensor<123x456xf32>) -> tensor<123x456xf32>
            scf.forall.in_parallel {
              tensor.parallel_insert_slice %12 into %arg1[0, 0, %arg0] [123, 456, 1] [1, 1, 1] : tensor<123x456xf32> into tensor<123x456x77xf32>
            }
          } {mapping = [#gpu.block<z>]}
          flow.dispatch.tensor.store %5, %3, offsets = [0, 0, 0], sizes = [123, 456, 77], strides = [1, 1, 1] : tensor<123x456x77xf32> -> !flow.dispatch.tensor<writeonly:tensor<123x456x77xf32>>
          return
        }
      }
    }
  }
  hal.executable private @matmul_static_dispatch_1 {
    hal.executable.variant public @cuda_nvptx_fb, target = #executable_target_cuda_nvptx_fb {
      hal.executable.export public @matmul_static_dispatch_1_generic_56088x77_f32 ordinal(0) layout(#pipeline_layout1) attributes {translation_info = #translation} {
      ^bb0(%arg0: !hal.device):
        %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
        hal.return %x, %y, %z : index, index, index
      }
      builtin.module {
        func.func @matmul_static_dispatch_1_generic_56088x77_f32() {
          %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c17275136) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<56088x77xf32>>
          ...
          %4 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "reduction"]} ins(%2 : tensor<56088x77xf32>) outs(%3 : tensor<56088xf32>) {
          ^bb0(%in: f32, %out: f32):
            %5 = arith.addf %in, %out : f32
            linalg.yield %5 : f32
          } -> tensor<56088xf32>
          flow.dispatch.tensor.store %4, %1, offsets = [0], sizes = [56088], strides = [1] : tensor<56088xf32> -> !flow.dispatch.tensor<readwrite:tensor<56088xf32>>
          return
        }
      }
    }
  }

  func.func @matmul_static(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
    ...
    %0 = stream.tensor.import %arg0 : !hal.buffer_view -> tensor<123x51234xf32> in !stream.resource<external>{%c25207128}
    ...
    %2 = stream.tensor.import %arg2 : !hal.buffer_view -> tensor<123x456xf32> in !stream.resource<external>{%c224352}
    ...
    stream.cmd.dispatch @matmul_static_dispatch_0::@cuda_nvptx_fb::@matmul_static_dispatch_0_matmul_123x456xD_f32
    stream.cmd.dispatch @matmul_static_dispatch_1::@cuda_nvptx_fb::@matmul_static_dispatch_1_generic_56088x77_f32
    ...
}
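The byte constants in the Step 3 IR can be cross-checked against the tensor shapes with a few lines of Python. This only does arithmetic on the shapes shown above. Reading `%c17275136` as "the 123x456x77xf32 partial-sum buffer rounded up to 64-byte alignment" is an inference from the numbers and from `alignment(64)` on the bindings; the IR does not state it.

```python
F32_BYTES = 4


def nbytes(*shape, elem_bytes=F32_BYTES):
    """Byte size of a dense tensor with the given static shape."""
    n = elem_bytes
    for dim in shape:
        n *= dim
    return n


def align_up(value, alignment):
    """Round value up to the next multiple of alignment."""
    return (value + alignment - 1) // alignment * alignment


lhs_bytes = nbytes(123, 51234)          # size passed to the %arg0 import
acc_bytes = nbytes(123, 456)            # size passed to the %arg2 import
partials_bytes = nbytes(123, 456, 77)   # split-K partial sums
# Inferred: binding offset = partials size rounded up to the 64-byte alignment.
partials_offset = align_up(partials_bytes, 64)
collapsed_rows = 123 * 456              # leading dim of dispatch_1 (56088x77)

print(lhs_bytes, acc_bytes, partials_bytes, partials_offset, collapsed_rows)
# prints: 25207128 224352 17275104 17275136 56088
```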

Step 4: Apply custom codegen to the matmul

Two tasks need to be solved to make this work today (see the FIXMEs).

For now, this is just the first step of unaligned matmul codegen.

  hal.executable private @matmul_static_dispatch_0 {
    hal.executable.variant public @cuda_nvptx_fb, target = #executable_target_cuda_nvptx_fb {
      hal.executable.export public @matmul_static_dispatch_0_matmul_123x456xD_f32 ordinal(0) layout(#pipeline_layout) attributes {translation_info = #translation, workgroup_size = [1 : index, 1 : index, 1 : index]} {
      ^bb0(%arg0: !hal.device):
        %c1 = arith.constant 1 : index
        %c77 = arith.constant 77 : index
        hal.return %c1, %c1, %c77 : index, index, index
      }
      builtin.module {
        func.func @matmul_static_dispatch_0_matmul_123x456xD_f32() {
          %c0 = arith.constant 0 : index
          %c0_0 = arith.constant 0 : index
          %c17275136 = arith.constant 17275136 : index
          %cst = arith.constant 0.000000e+00 : f32
          %c16 = arith.constant 16 : index
          %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0_0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<123x51234xf32>>
          %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0_0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<51234x456xf32>>
          %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0_0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<123x456x77xf32>>
          %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c17275136) : !flow.dispatch.tensor<writeonly:tensor<123x456x77xf32>>
          %4 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [123, 456, 77], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<123x456x77xf32>> -> tensor<123x456x77xf32>
          %5 = scf.forall (%arg0) in (77) shared_outs(%arg1 = %4) -> (tensor<123x456x77xf32>) {
            %6 = affine.min #map(%arg0)
            %7 = affine.max #map1(%6)
            %extracted_slice = tensor.extract_slice %arg1[0, 0, %arg0] [123, 456, 1] [1, 1, 1] : tensor<123x456x77xf32> to tensor<123x456x1xf32>
            %8 = linalg.fill ins(%cst : f32) outs(%extracted_slice : tensor<123x456x1xf32>) -> tensor<123x456x1xf32>
            %extracted_slice_1 = tensor.extract_slice %8[0, 0, 0] [123, 456, 1] [1, 1, 1] : tensor<123x456x1xf32> to tensor<123x456xf32>
            %9 = affine.apply #map2(%arg0)
            %10 = flow.dispatch.tensor.load %0, offsets = [0, %9], sizes = [123, %7], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<123x51234xf32>> -> tensor<123x?xf32>
            %11 = flow.dispatch.tensor.load %1, offsets = [%9, 0], sizes = [%7, 456], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<51234x456xf32>> -> tensor<?x456xf32>
            %12 = scf.forall (%arg2, %arg3) in (1, 4) shared_outs(%arg4 = %extracted_slice_1) -> (tensor<123x456xf32>) {
              %13 = affine.min #map3(%arg3)
              %14 = affine.apply #map4(%arg2)
              %15 = affine.apply #map4(%arg3)
              %extracted_slice_2 = tensor.extract_slice %10[%14, 0] [123, %7] [1, 1] : tensor<123x?xf32> to tensor<123x?xf32>
              %extracted_slice_3 = tensor.extract_slice %11[0, %15] [%7, %13] [1, 1] : tensor<?x456xf32> to tensor<?x?xf32>
              %extracted_slice_4 = tensor.extract_slice %arg4[%14, %15] [123, %13] [1, 1] : tensor<123x456xf32> to tensor<123x?xf32>
              %16 = scf.for %arg5 = %c0_0 to %7 step %c16 iter_args(%arg6 = %extracted_slice_4) -> (tensor<123x?xf32>) {
                %17 = affine.min #map5(%arg5)[%7]
                %extracted_slice_5 = tensor.extract_slice %extracted_slice_3[%arg5, 0] [%17, %13] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
                %extracted_slice_6 = tensor.extract_slice %arg6[0, 0] [123, %13] [1, 1] : tensor<123x?xf32> to tensor<123x?xf32>
                %extracted_slice_7 = tensor.extract_slice %extracted_slice_2[0, %arg5] [123, %17] [1, 1] : tensor<123x?xf32> to tensor<123x?xf32>
                %18 = linalg.matmul ins(%extracted_slice_7, %extracted_slice_5 : tensor<123x?xf32>, tensor<?x?xf32>) outs(%extracted_slice_6 : tensor<123x?xf32>) -> tensor<123x?xf32>
                %inserted_slice = tensor.insert_slice %18 into %arg6[0, 0] [123, %13] [1, 1] : tensor<123x?xf32> into tensor<123x?xf32>
                scf.yield %inserted_slice : tensor<123x?xf32>
              }
              scf.forall.in_parallel {
                tensor.parallel_insert_slice %16 into %arg4[%14, %15] [123, %13] [1, 1] : tensor<123x?xf32> into tensor<123x456xf32>
              }
            } {mapping = [#gpu.block<y>, #gpu.block<x>]}
            scf.forall.in_parallel {
              tensor.parallel_insert_slice %12 into %arg1[0, 0, %arg0] [123, 456, 1] [1, 1, 1] : tensor<123x456xf32> into tensor<123x456x77xf32>
            }
          } {mapping = [#gpu.block<z>]}
          flow.dispatch.tensor.store %5, %3, offsets = [0, 0, 0], sizes = [123, 456, 77], strides = [1, 1, 1] : tensor<123x456x77xf32> -> !flow.dispatch.tensor<writeonly:tensor<123x456x77xf32>>
          return
        }
      }
    }
  }
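The Step 4 loop nest tiles at three levels: 77 K-chunks on `#gpu.block<z>`, a 1x4 grid of (M, N) tiles on `#gpu.block<y>`/`#gpu.block<x>`, and a sequential `scf.for` over each K-chunk with step 16 whose last step is clamped by `affine.min`. The sketch below only enumerates those (offset, size) pairs. The maps `#map3`/`#map4` are not shown in the excerpt, so the N tile size of 128 is a hypothetical value (any size from 114 to 151 gives 4 tiles for N = 456); the K-chunk size of 666 follows from ceil(51234 / 77).

```python
def tiles(extent, tile):
    """(offset, size) pairs covering [0, extent); the last size is clamped
    like affine.min(tile, extent - offset)."""
    return [(off, min(tile, extent - off)) for off in range(0, extent, tile)]


def ceil_div(a, b):
    return -(-a // b)


K, SPLIT = 51234, 77
K_CHUNK = ceil_div(K, SPLIT)   # 666
N, N_TILE = 456, 128           # hypothetical: #map3/#map4 are not shown
K_STEP = 16                    # step of the innermost scf.for

k_chunks = tiles(K, K_CHUNK)                    # one per #gpu.block<z>
n_tiles = tiles(N, N_TILE)                      # #gpu.block<x>; M = 123 is one tile
k_steps_full = tiles(K_CHUNK, K_STEP)           # scf.for inside a full chunk
k_steps_last = tiles(k_chunks[-1][1], K_STEP)   # scf.for inside the last chunk

print(len(k_chunks), k_chunks[-1])   # 77 (50616, 618)
print(n_tiles)                       # [(0, 128), (128, 128), (256, 128), (384, 72)]
print(len(k_steps_full), k_steps_full[-1], len(k_steps_last), k_steps_last[-1])
# 42 (656, 10) 39 (608, 10)
```

Every level ends with a partial tile, which is why the dynamic `tensor<123x?xf32>`/`tensor<?x?xf32>` slices appear throughout the IR above.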

Step 5: Use IREE for the codegen of the other dispatch (custom dispatch formation and IREE's dispatch formation just compose)

Calling iree-compile then takes care of the unaligned reduction transparently, but the two tasks mentioned above need to be fixed first.

Hopefully this clearly demonstrates that we have all the pieces in place to achieve split-K (and many more things) by simply applying a subset of the codegen abstractions at the graph level and having them compose by construction with the rest of IREE.
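The "other dispatch" is the reduction over the split dimension. In Step 3 it appears as `matmul_static_dispatch_1_generic_56088x77_f32`: the 123x456x77 partials viewed as 56088x77 (123 * 456 = 56088), reduced along the last dimension and accumulated into the readwrite output. The small Python check below, on made-up shapes, illustrates why that collapse is legal: in a row-major layout, flattening (i, j) into i * N + j keeps each reduced row contiguous, so the result is unchanged.

```python
def reduce_3d(partials, acc):
    """out[i][j] = acc[i][j] + sum_z partials[i][j][z] (the Step 1 linalg.generic)."""
    return [[acc[i][j] + sum(partials[i][j]) for j in range(len(acc[0]))]
            for i in range(len(acc))]


def reduce_collapsed(partials, acc):
    """Same reduction on the collapsed (M*N) x S view used by dispatch_1."""
    rows = [row for plane in partials for row in plane]   # (i, j) -> i*N + j
    flat_acc = [x for row in acc for x in row]
    return [a + sum(r) for a, r in zip(flat_acc, rows)]


M, N, S = 3, 4, 5  # small stand-ins for 123, 456, 77
partials = [[[i * 100 + j * 10 + z for z in range(S)] for j in range(N)]
            for i in range(M)]
acc = [[i + j for j in range(N)] for i in range(M)]

out_3d = reduce_3d(partials, acc)
out_flat = reduce_collapsed(partials, acc)
assert out_flat == [x for row in out_3d for x in row]
```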

@stellaraccident @mattwalsh @manishucsd @qcolombet with whom I spoke about some of this this week.

nicolasvasilache commented 1 year ago

I believe this becomes available in IREE with little effort and involves:

  1. extend iree.populate_workgroup_count_region_using_num_threads_slice to work on nested scf.forall on blocks
  2. extend iree.forall_to_workgroup to work on nested scf.forall on blocks
  3. optionally clean up/unify iree.forall_to_workgroup with iree.map_nested_forall_to_gpu_threads, which also brings a bunch of extra benefits that I won't detail here
  4. allow such pluggable graph-level processing to compose naturally with IREE by landing an updated version of #11886, or an equivalent solution that drives this with the same matchers and transforms we already use for some of the CUDA codegen.
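The gap behind item 1 is visible in Step 4: the exported workgroup count returns `%c1, %c1, %c77`, which only accounts for the outer `scf.forall` on `#gpu.block<z>`, while the nested `scf.forall (%arg2, %arg3) in (1, 4)` is mapped to `#gpu.block<y>, #gpu.block<x>`. A minimal sketch of a combined count, assuming each nesting level claims distinct block dimensions, could look like this; `nested_workgroup_count` is a hypothetical helper, not an IREE API.

```python
def nested_workgroup_count(foralls):
    """Combine nested scf.forall trip counts into an (x, y, z) workgroup count.

    Hypothetical helper. `foralls` lists (trip_counts, block_dims) from the
    outermost to the innermost loop, e.g. ((77,), ("z",)). Each block dim
    may be claimed by only one loop; unclaimed dims default to 1.
    """
    grid = {"x": 1, "y": 1, "z": 1}
    claimed = set()
    for trip_counts, block_dims in foralls:
        if len(trip_counts) != len(block_dims):
            raise ValueError("one block dim is needed per induction variable")
        for count, dim in zip(trip_counts, block_dims):
            if dim in claimed:
                raise ValueError(f"#gpu.block<{dim}> is mapped more than once")
            claimed.add(dim)
            grid[dim] = count
    return grid["x"], grid["y"], grid["z"]


# Step 4: outer forall in (77) on block<z>; inner forall in (1, 4) on block<y>, block<x>.
print(nested_workgroup_count([((77,), ("z",)), ((1, 4), ("y", "x"))]))  # (4, 1, 77)
```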
MaheshRavishankar commented 1 year ago

Coming back to this after vacation, I see quite a few alternatives being discussed, some more or less speculative. I'll do my best to answer in a structured way by just showing what we have available starting from this commit: iree-org/iree-samples@e030b0a.

You can dig into more details here: iree-org/iree-samples@e030b0a#diff-4d6be20e8b5295b925af9509d5de4e32aba9e0837317eece2fd7c8d631c3e795R2, and reproduce by copy-pasting the commands provided (and occasionally commenting out parts of the script to see intermediate state).

Step 0: Input IR.

  func.func @matmul_static(%arg0: tensor<123x51234xf32>, %arg1: tensor<51234x456xf32>, %arg2: tensor<123x456xf32>) -> tensor<123x456xf32> {
    %0 = linalg.matmul ins(%arg0, %arg1 : tensor<123x51234xf32>, tensor<51234x456xf32>) outs(%arg2 : tensor<123x456xf32>) -> tensor<123x456xf32>
    return %0 : tensor<123x456xf32>
  }

Step 1: Split-K and fuse fill with scf.forall at the graph level.

  func.func @matmul_static(%arg0: tensor<123x51234xf32>, %arg1: tensor<51234x456xf32>, %arg2: tensor<123x456xf32>) -> tensor<123x456xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<123x456x77xf32>
    %1 = scf.forall (%arg3) in (77) shared_outs(%arg4 = %0) -> (tensor<123x456x77xf32>) {
      %3 = affine.min #map(%arg3)
      %4 = affine.max #map1(%3)
      %extracted_slice = tensor.extract_slice %arg4[0, 0, %arg3] [123, 456, 1] [1, 1, 1] : tensor<123x456x77xf32> to tensor<123x456x1xf32>
      %5 = linalg.fill ins(%cst : f32) outs(%extracted_slice : tensor<123x456x1xf32>) -> tensor<123x456x1xf32>
      %extracted_slice_0 = tensor.extract_slice %5[0, 0, 0] [123, 456, 1] [1, 1, 1] : tensor<123x456x1xf32> to tensor<123x456xf32>
      %6 = affine.apply #map2(%arg3)
      %extracted_slice_1 = tensor.extract_slice %arg0[0, %6] [123, %4] [1, 1] : tensor<123x51234xf32> to tensor<123x?xf32>
      %extracted_slice_2 = tensor.extract_slice %arg1[%6, 0] [%4, 456] [1, 1] : tensor<51234x456xf32> to tensor<?x456xf32>
      %7 = linalg.matmul ins(%extracted_slice_1, %extracted_slice_2 : tensor<123x?xf32>, tensor<?x456xf32>) outs(%extracted_slice_0 : tensor<123x456xf32>) -> tensor<123x456xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %7 into %arg4[0, 0, %arg3] [123, 456, 1] [1, 1, 1] : tensor<123x456xf32> into tensor<123x456x77xf32>
      }
    } {mapping = [#gpu.block<z>]}
    %2 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1 : tensor<123x456x77xf32>) outs(%arg2 : tensor<123x456xf32>) {
    ^bb0(%in: f32, %out: f32):
      %3 = arith.addf %in, %out : f32
      linalg.yield %3 : f32
    } -> tensor<123x456xf32>
    return %2 : tensor<123x456xf32>
  }
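Step 1 is the core of the unaligned split: each of the 77 `scf.forall` iterations takes a K-chunk of size ceil(51234 / 77) = 666, clamps its size with `affine.min`/`affine.max`, and writes a zero-initialized partial product into its own slice of the 123x456x77 buffer; a separate `linalg.generic` then reduces over the 77 slices into `%arg2`. The maps themselves are not printed, so the formulas below for `#map2` (offset `i * c`), `#map` (`min(c, K - offset)`) and `#map1` (`max(size, 0)`) are assumptions consistent with the shapes. The `max` matters when the fixed trip count over-covers K: with K = 10 split 7 ways the chunk size is 2, and the last two chunks would otherwise get negative sizes.

```python
def ceil_div(a, b):
    return -(-a // b)


def chunk_bounds(K, splits):
    """(offset, size) per scf.forall iteration; the map formulas are assumed."""
    c = ceil_div(K, splits)
    bounds = []
    for i in range(splits):
        offset = i * c                       # assumed #map2
        size = max(min(c, K - offset), 0)    # assumed #map, then #map1
        bounds.append((offset, size))
    return bounds


def split_k_matmul(A, B, C, splits):
    """Reference semantics of Step 1: one partial matmul per K-chunk, then reduce."""
    M, K, N = len(A), len(B), len(B[0])
    # partials[i][j][z] plays the role of the MxNx(splits) buffer.
    partials = [[[0] * splits for _ in range(N)] for _ in range(M)]
    for z, (off, size) in enumerate(chunk_bounds(K, splits)):
        for i in range(M):
            for j in range(N):
                # linalg.fill(0) followed by linalg.matmul on this chunk
                partials[i][j][z] = sum(A[i][k] * B[k][j] for k in range(off, off + size))
    # Final linalg.generic: reduce over z, accumulating into C.
    return [[C[i][j] + sum(partials[i][j]) for j in range(N)] for i in range(M)]


def matmul(A, B, C):
    """Unsplit reference: C + A @ B."""
    return [[C[i][j] + sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]


A = [[1, 2, 3], [4, 5, 6]]   # K = 3 split 2 ways: chunks of size 2 and 1
B = [[1, 0], [0, 1], [1, 1]]
C = [[0, 0], [0, 0]]
print(split_k_matmul(A, B, C, 2))   # [[4, 5], [10, 11]]
```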

Step 2: Convert the scf.forall into a custom flow.dispatch.workgroups.

  func.func @matmul_static(%arg0: tensor<123x51234xf32>, %arg1: tensor<51234x456xf32>, %arg2: tensor<123x456xf32>) -> tensor<123x456xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<123x456x77xf32>
    %1 = flow.dispatch.workgroups(%cst, %arg0, %arg1, %0) : (f32, tensor<123x51234xf32>, tensor<51234x456xf32>, tensor<123x456x77xf32>) -> tensor<123x456x77xf32> =
        (%arg3: f32, %arg4: !flow.dispatch.tensor<readonly:tensor<123x51234xf32>>, %arg5: !flow.dispatch.tensor<readonly:tensor<51234x456xf32>>, %arg6: !flow.dispatch.tensor<readonly:tensor<123x456x77xf32>>, %arg7: !flow.dispatch.tensor<writeonly:tensor<123x456x77xf32>>) {
      %3 = flow.dispatch.tensor.load %arg4, offsets = [0, 0], sizes = [123, 51234], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<123x51234xf32>> -> tensor<123x51234xf32>
      %4 = flow.dispatch.tensor.load %arg5, offsets = [0, 0], sizes = [51234, 456], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<51234x456xf32>> -> tensor<51234x456xf32>
      %5 = flow.dispatch.tensor.load %arg6, offsets = [0, 0, 0], sizes = [123, 456, 77], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<123x456x77xf32>> -> tensor<123x456x77xf32>
      %6 = scf.forall (%arg8) in (77) shared_outs(%arg9 = %5) -> (tensor<123x456x77xf32>) {
        %7 = affine.min #map(%arg8)
        %8 = affine.max #map1(%7)
        %extracted_slice = tensor.extract_slice %arg9[0, 0, %arg8] [123, 456, 1] [1, 1, 1] : tensor<123x456x77xf32> to tensor<123x456x1xf32>
        %9 = linalg.fill ins(%arg3 : f32) outs(%extracted_slice : tensor<123x456x1xf32>) -> tensor<123x456x1xf32>
        %extracted_slice_0 = tensor.extract_slice %9[0, 0, 0] [123, 456, 1] [1, 1, 1] : tensor<123x456x1xf32> to tensor<123x456xf32>
        %10 = affine.apply #map2(%arg8)
        %extracted_slice_1 = tensor.extract_slice %3[0, %10] [123, %8] [1, 1] : tensor<123x51234xf32> to tensor<123x?xf32>
        %extracted_slice_2 = tensor.extract_slice %4[%10, 0] [%8, 456] [1, 1] : tensor<51234x456xf32> to tensor<?x456xf32>
        %11 = linalg.matmul ins(%extracted_slice_1, %extracted_slice_2 : tensor<123x?xf32>, tensor<?x456xf32>) outs(%extracted_slice_0 : tensor<123x456xf32>) -> tensor<123x456xf32>
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %11 into %arg9[0, 0, %arg8] [123, 456, 1] [1, 1, 1] : tensor<123x456xf32> into tensor<123x456x77xf32>
        }
      } {mapping = [#gpu.block<z>]}
      flow.dispatch.tensor.store %6, %arg7, offsets = [0, 0, 0], sizes = [123, 456, 77], strides = [1, 1, 1] : tensor<123x456x77xf32> -> !flow.dispatch.tensor<writeonly:tensor<123x456x77xf32>>
      flow.return
    }
    %2 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1 : tensor<123x456x77xf32>) outs(%arg2 : tensor<123x456xf32>) {
    ^bb0(%in: f32, %out: f32):
      %3 = arith.addf %in, %out : f32
      linalg.yield %3 : f32
    } -> tensor<123x456xf32>
    return %2 : tensor<123x456xf32>
  }

The issue here is that this works by knowing the split-K value statically. Also, using a transform dialect script for the whole program is unclear to me (it will work for a single GEMM and is useful for flushing out backend issues). The end state of the split-K work is to allow the split value to be dynamic, but that seems to be out of scope for the current work.

I also think (for this use case) scf.forall -> flow.dispatch.workgroups is the reverse of what I would expect. There is also the issue that doing this "early" means it doesn't fuse the same way as the current non-split-K path. The final split-K solution should happen in conjunction with how dispatch regions are formed today.

I am trying to understand what the expected end state is here. If it is just to get GEMM performance up with split-K, this sounds good. Making this shippable will need work.

nicolasvasilache commented 1 year ago

This required higher bandwidth, so we connected offline last week about what the expected end state is here.

Jotting some of those thoughts down; my memory may be lapsing, so please correct me if I'm wrong.

The end state in IREE involves considerations related to:

  1. specifying the split amount "KK" dynamically to be target-agnostic
  2. deciding where this "KK" is computed (host and/or device), how the graph and codegen share that convention and how codegen adapts to this
  3. considering tradeoffs related to if/else "KK" on the host to avoid potentially launching an extra empty kernel
  4. avoiding creating a custom dispatch region and custom fusion decision but instead: i. using the default IREE fusion heuristic ii. using the default IREE dispatch region formation iii. applying the transformation to scf.forall within the dispatch iv. performing a subsequent splitting of the dispatch in 2 other dispatches.
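Items 1-3 interact: once the split amount "KK" is a runtime value, a fixed trip count of KK can produce chunks whose clamped size is zero, and KK = 1 makes a second (reduction) dispatch pointless. The sketch below is one possible host-side decision under those assumptions and is only one reading of item 3; `SplitPlan` and `plan_split_k` are hypothetical names, and trimming the grid to the non-empty chunks is an illustrative choice, not something agreed on here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SplitPlan:
    split: bool   # whether to run the split-K path at all
    chunk: int    # K elements per chunk (equals K when not splitting)
    grid_z: int   # number of chunks actually launched


def ceil_div(a, b):
    return -(-a // b)


def plan_split_k(K, KK):
    """Hypothetical host-side policy for a runtime split factor KK (K >= 0)."""
    if K == 0 or KK <= 1:
        return SplitPlan(split=False, chunk=K, grid_z=1)
    chunk = ceil_div(K, KK)
    # Chunk i is non-empty iff i * chunk < K; the remaining ones would have size 0.
    non_empty = ceil_div(K, chunk)
    if non_empty == 1:
        # A single chunk is just the plain matmul: no reduction dispatch needed.
        return SplitPlan(split=False, chunk=K, grid_z=1)
    return SplitPlan(split=True, chunk=chunk, grid_z=non_empty)


print(plan_split_k(51234, 77))  # all 77 chunks are non-empty
print(plan_split_k(10, 7))      # chunk size 2: only 5 of the 7 chunks are non-empty
print(plan_split_k(10, 1))      # no split
```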

I am not sure that the above is truly retargetable (I am happy to believe it and adapt) or that the above considerations are achievable "soon", but once this end state was laid out, we narrowed down what we can actually do to make progress.

The suggestion was to implement a new custom pass, placed just before dispatch region formation, that:

  1. implements the graph-level part of iree-org/iree-samples@e030b0a
  2. uses a static tile size to avoid conflating all concerns before we get something viable
  3. uses the same transformations demonstrated in iree-org/iree-samples@e030b0a, including non-default IREE fusion for now
  4. uses the underlying C++ implementation of the transforms in that new custom pass (i.e. copying the underlying C++ to avoid staging through the transform dialect)
  5. applies steps 1-3 from the previous suggestion.

I am not sure this is particularly retargetable, and it will probably need to live behind yet another flag at the flow level, but it seems actionable enough to implement iree-org/iree-samples@e030b0a and iterate later.

MaheshRavishankar commented 1 year ago

Items 1-4 sound good. It is unclear to me which one the previous suggestion was: is it https://github.com/openxla/iree/issues/13115#issuecomment-1543844223 ? I thought the plan was to drop line 69 here (https://github.com/iree-org/iree-samples/commit/e030b0a49ce434dd286390f7fa7b66ebbccf0fad#diff-4d6be20e8b5295b925af9509d5de4e32aba9e0837317eece2fd7c8d631c3e795R69) and just form the flow.dispatch.region, then use the regular flow to form flow.dispatch.workgroups from that (we don't want too many diverging paths that do the same thing).

Yes, this will need to be behind a flag because the static size needs architecture information. I will take a stab at addressing this for good, but that will have to come later. At least this will unblock the current issues.

nicolasvasilache commented 1 year ago

The link points (at least for me) to the nested scf.forall mapping support (extending iree.populate_workgroup_count_region_using_num_threads_slice to work on nested scf.forall on blocks, etc.).

MaheshRavishankar commented 1 year ago

Ok, happy to work through the details. Just recording that we want to go from scf.forall to flow.dispatch.region and use the normal flow.dispatch.region -> flow.dispatch.workgroups conversion. Happy to help with any issues this may cause (it is not worth having two paths that do the same thing, or exposing all sorts of implementation details of the conversion so that it can be used outside of the pass).

allieculp commented 1 year ago

From meeting today, moving to @qcolombet.

ftynse commented 1 year ago

A sample script for splitting using the transform dialect (TD) is available in iree-samples (https://github.com/iree-org/iree-samples/commit/05ee2175f1bea82e7d853f66217667f5497467ef) and is being ported to on-the-fly generation via an extension in openxla-nvgpu.

qcolombet commented 1 year ago

Thanks @ftynse for the update. To capture what I discussed with @nicolasvasilache as well: the performance is not quite where we would like it to be, because split-K pushes us into batch-matmul territory, and it turns out our performance story there is not great yet.
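One way to see why split-K lands in batch-matmul territory: if each K-chunk is zero-padded to the common chunk size c = ceil(K / KK), the partial products form a batch of KK independent (M x c) by (c x N) matmuls, and summing over the batch recovers the original product. In the case above that would be 77 matmuls of 123x666 by 666x456. The sketch checks this on small integer matrices; zero-padding is an illustrative way to make the batch uniform, whereas Step 1 in this thread keeps the last chunk ragged with `affine.min` instead.

```python
def ceil_div(a, b):
    return -(-a // b)


def split_k_as_batch_matmul(A, B, KK):
    """Return KK partial products, each an (M x c) @ (c x N) matmul.

    Illustrative only: K is zero-padded up to KK * c so every batch entry
    has the same shape.
    """
    M, K, N = len(A), len(B), len(B[0])
    c = ceil_div(K, KK)
    pad = KK * c - K
    A_pad = [row + [0] * pad for row in A]        # M x (KK * c)
    B_pad = B + [[0] * N for _ in range(pad)]     # (KK * c) x N
    batch = []
    for z in range(KK):
        lhs = [row[z * c:(z + 1) * c] for row in A_pad]   # M x c
        rhs = B_pad[z * c:(z + 1) * c]                     # c x N
        batch.append([[sum(lhs[i][k] * rhs[k][j] for k in range(c))
                       for j in range(N)] for i in range(M)])
    return batch


def sum_batch(batch):
    """Reduce over the batch dimension (the split-K reduction)."""
    M, N = len(batch[0]), len(batch[0][0])
    return [[sum(b[i][j] for b in batch) for j in range(N)] for i in range(M)]
```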