nirvedhmeshram opened this issue 3 weeks ago
I am realizing that the matmul-like op shared in the gist in the issue is quite edge-casey, so we might want to look at the whole model, ask why we have this matmul-like op, and consider whether we should have done something differently in pre-processing to avoid reaching it. Here is the front-end program that produces this shape:
module {
func.func @main_graph(%arg1: !torch.vtensor<[1,64,112,112],f32> , %arg2: !torch.vtensor<[1,128,56,56],f32>, %arg3: !torch.vtensor<[64,64,3,3],f32>, %arg4: !torch.vtensor<[64],f32> , %arg5: !torch.vtensor<[128,64,1,1],f32>, %arg6: !torch.vtensor<[128],f32>) -> !torch.vtensor<[1,128,56,56],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.opset_versions = {ai.onnx.contrib = 1 : si64, ai.onnx.ml = 4 : si64, ai.onnx.preview.training = 1 : si64, ai.onnx.training = 1 : si64, com.microsoft = 1 : si64, com.microsoft.experimental = 1 : si64, com.microsoft.nchwc = 1 : si64, org.pytorch.aten = 1 : si64}, torch.onnx_meta.producer_name = "vai_q_onnx", torch.onnx_meta.producer_version = "1.17.0+43059a7"} {
%1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<si8>} : () -> !torch.vtensor<[],si8>
%2 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1.562500e-02> : tensor<f32>} : () -> !torch.vtensor<[],f32>
%3 = torch.operator "onnx.QuantizeLinear"(%arg2, %2, %1) : (!torch.vtensor<[1,128,56,56],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[1,128,56,56],si8>
%4 = torch.operator "onnx.Conv"(%arg1, %arg3, %arg4) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [1 : si64, 1 : si64, 1 : si64, 1 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,64,112,112],f32>, !torch.vtensor<[64,64,3,3],f32>, !torch.vtensor<[64],f32>) -> !torch.vtensor<[1,64,56,56],f32>
%5 = torch.operator "onnx.DequantizeLinear"(%3, %2, %1) : (!torch.vtensor<[1,128,56,56],si8>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[1,128,56,56],f32>
%6 = torch.operator "onnx.Relu"(%4) : (!torch.vtensor<[1,64,56,56],f32>) -> !torch.vtensor<[1,64,56,56],f32>
%7 = torch.operator "onnx.Conv"(%6, %arg5, %arg6) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [1 : si64, 1 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [1 : si64, 1 : si64]} : (!torch.vtensor<[1,64,56,56],f32>, !torch.vtensor<[128,64,1,1],f32>, !torch.vtensor<[128],f32>) -> !torch.vtensor<[1,128,56,56],f32>
%8 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<si8>} : () -> !torch.vtensor<[],si8>
%9 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1.562500e-02> : tensor<f32>} : () -> !torch.vtensor<[],f32>
%10 = torch.operator "onnx.QuantizeLinear"(%7, %9, %8) : (!torch.vtensor<[1,128,56,56],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[1,128,56,56],si8>
%11 = torch.operator "onnx.DequantizeLinear"(%10, %9, %8) : (!torch.vtensor<[1,128,56,56],si8>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[1,128,56,56],f32>
%12 = torch.operator "onnx.Add"(%11, %5) : (!torch.vtensor<[1,128,56,56],f32>, !torch.vtensor<[1,128,56,56],f32>) -> !torch.vtensor<[1,128,56,56],f32>
%13 = torch.operator "onnx.Relu"(%12) : (!torch.vtensor<[1,128,56,56],f32>) -> !torch.vtensor<[1,128,56,56],f32>
return %13 : !torch.vtensor<[1,128,56,56],f32>
}
}
cc @MaheshRavishankar @IanWood1
The matmul is coming from %7 = torch.operator "onnx.Conv"
which gets converted to (compile-to=input):
%13 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%12, %4 : tensor<1x64x56x56xf32>, tensor<128x64x1x1xf32>) outs(%broadcasted_3 : tensor<1x128x56x56xf32>) -> tensor<1x128x56x56xf32>
And then generalized to the linalg.generic matmul-like op (from https://github.com/iree-org/iree/pull/18736). @nirvedhmeshram I'm not sure how this could be better represented or what the issue is, however.
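For reference, the generalized op is a contraction roughly of the following form (a hand-written sketch, not copied from the actual dump; unit dims are assumed folded away, and the SSA names and dimension order are illustrative):
#mapIn  = affine_map<(f, oh, ow, c) -> (c, oh, ow)>  // input, 64x56x56
#mapFil = affine_map<(f, oh, ow, c) -> (f, c)>       // filter, 128x64
#mapOut = affine_map<(f, oh, ow, c) -> (f, oh, ow)>  // result, 128x56x56
%gen = linalg.generic {
    indexing_maps = [#mapIn, #mapFil, #mapOut],
    iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
    ins(%input, %filter : tensor<64x56x56xf32>, tensor<128x64xf32>)
    outs(%acc : tensor<128x56x56xf32>) {
^bb0(%in: f32, %flt: f32, %out: f32):
  // multiply-accumulate over the reduction dim c
  %mul = arith.mulf %in, %flt : f32
  %add = arith.addf %out, %mul : f32
  linalg.yield %add : f32
} -> tensor<128x56x56xf32>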
Thanks for taking a look. I don't think there is any reason it can't be supported; it probably just needs new configuration logic (and we'd need to see if anything breaks), so we can take a look at adding that.
So far my progress is:
- The baseline SIMT path fails in GPUCheckResourceUsagePass. It is quite understandable why this lowering would exceed the LDS size limit: the IR is attempting to naively load a fair amount of tensor data from global memory to LDS.
- --iree-codegen-llvmgpu-test-tile-and-fuse-matmul yields exactly the same failure as baseline SIMT.
- --iree-codegen-llvmgpu-test-tile-and-fuse-vectorize successfully passes the resource check and compiles.
The shapes involved:
; filter: 128x64x1x1
; input: 1x64x56x56
; output: 1x128x56x56
; gemmA: 128x64
; gemmB: 64x56x56
; gemmC: 128x56x56
(For context, the input %6 of this 1x1 conv is produced by:)
%4 = torch.operator "onnx.Conv"(%arg1, %arg3, %arg4) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [1 : si64, 1 : si64, 1 : si64, 1 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,64,112,112],f32>, !torch.vtensor<[64,64,3,3],f32>, !torch.vtensor<[64],f32>) -> !torch.vtensor<[1,64,56,56],f32>
%6 = torch.operator "onnx.Relu"(%4) : (!torch.vtensor<[1,64,56,56],f32>) -> !torch.vtensor<[1,64,56,56],f32>
Next steps:
- Run tileAndFuse matmul and tileAndFuse vectorize, see in particular how the two options impact the LLVMGPUSelectLoweringStrategyPass, and understand why one passes but the other fails.
- Understand how tileAndFuse matmul is able to handle it. This is a matmul-like op, except that the gemmB argument has one additional trailing dimension; if I concat [d1, d2] together, then this is exactly a matmul.
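To illustrate what concatenating [d1, d2] amounts to, here is a sketch (assuming the collapse happens on the 56x56 spatial dims; the SSA names are illustrative and not taken from an actual dump):
%b  = tensor.collapse_shape %gemmB [[0], [1, 2]]
    : tensor<64x56x56xf32> into tensor<64x3136xf32>
%c0 = tensor.collapse_shape %gemmC_init [[0], [1, 2]]
    : tensor<128x56x56xf32> into tensor<128x3136xf32>
// With the spatial dims collapsed, this is a plain (128x64) x (64x3136) matmul.
%c = linalg.matmul ins(%gemmA, %b : tensor<128x64xf32>, tensor<64x3136xf32>)
    outs(%c0 : tensor<128x3136xf32>) -> tensor<128x3136xf32>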
Yeah, this is something that should be handled by compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp, but it's currently blocked by some codegen issues (which should be resolved shortly by https://github.com/iree-org/iree/pull/18822).
I quickly commented this line out and was able to get it to compile successfully.
@IanWood1 Thanks for putting in the extra effort to get it to compile. It is not immediately clear to me which line you commented out to make it work, but I assume you made it so that 56x56 gets collapsed into a single dimension. With that single dimension, tileAndFuse Matmul would be able to work correctly (whereas @nirvedhmeshram pointed out that he thinks 56x56 not being divisible by 64 could be the problem, and kindly provided me the tracking ticket for making that work, here). Assuming I understand the background correctly, this means that once #18822 gets merged, the symptom in this ticket will go away, as you've demonstrated. That's good news and reassures me that this ticket will be resolved even if I don't do anything about it :-p
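For what it's worth, the divisibility also works out after collapsing: $56 \times 56 = 3136 = 49 \times 64$, so the collapsed spatial dimension is a multiple of 64 (and of 16), whereas 56 alone is not.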
I chatted with @nirvedhmeshram a moment ago; a few things still don't fully explain themselves and are worth following through on.
Next steps:
- Confirm that the reason the existing compilation failed is indeed what we assume it is.
Summary and conclusions:
Why was tileAndFuse matmul failing?
It fails to deduce a valid MMA schedule (in setMatmulLoweringConfig()). In particular, the gemm dimensions are [m: 56, n: 128, k: 64], which are not divisible by the 16x16x4 mfma instruction shape, so it gives up per the check below: https://github.com/iree-org/iree/blob/842bcbc2427779db72ee8405cd3bd6f80ceacc53/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp#L233-L235
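Concretely (assuming the 16x16x4 intrinsic shape corresponds to M, N, K respectively), only the M dimension violates the divisibility requirement: $56 \bmod 16 = 8 \neq 0$, while $128 \bmod 16 = 0$ and $64 \bmod 4 = 0$.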
Why can tileAndFuse vectorize pass?
It goes down a different routine (setTileAndFuseLoweringConfig()). For the config size of [128, 56, 56], the picked workgroup tile size is [4, 8, 8] and the per-thread tile size is [1, 1, 4]. (Side topic: considering that we have 64 threads per wave, this configuration uses 1 wave per workgroup. From a performance point of view, the occupancy of the generated kernel seems quite poor.)
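A quick check of the 1-wave claim (assuming each thread covers exactly its per-thread tile): $\frac{4 \cdot 8 \cdot 8}{1 \cdot 1 \cdot 4} = \frac{256}{4} = 64$ threads per workgroup, i.e. exactly one 64-wide wave.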
Why was I able to pass tileAndFuse matmul by pulling out the IR before the SelectLoweringStrategy pass?
Because I forgot to use the --iree-codegen-llvmgpu-test-tile-and-fuse-matmul argument when invoking iree-opt, therefore falling back to setVectorDistributionConfig(). The part I don't understand yet is how to make sense of the lowering_config<tile_sizes = [[1, 64, 128, 64]]> for the dimension sizes [128, 56, 56, 64]; I'd like to understand more about how the conventional SIMT pass allocates tiles.
For next steps, this ticket is on hold; resolution of either of the items below will make the problem in this ticket go away:
For this matmul-like + elementwise IR, we go down the LLVMGPUSIMT pipeline (see dump here). Today TileAndFuse Vectorize can handle this case correctly, but ideally we want this to be handled by the TileAndFuse Matmul pipeline.
Compile command for SIMT
Compile command for TileandFuse Vectorize