iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/

Multi-device 2D convolution numerical imprecision #18283

Open sogartar opened 3 weeks ago

sogartar commented 3 weeks ago

I encountered numerical imprecision when executing on 2 local-task devices (most likely it is actually a compiler lowering bug). I have narrowed it down to a 2D convolution without padding or bias. Interestingly, with padding the result is correct.

conv2d-multi-device-numerical-imprecision.zip (updated) contains a reproducer.

My assumption is that the lowering of the different input preparation for the main conv dispatch is wrong. When using padding, we first zero-fill a tensor and then insert the input according to the padding size; without padding, we use the input tensor directly.
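Roughly, the padded path prepares the conv input like this (a sketch only; the shapes and SSA names are illustrative, not taken from the reproducer), whereas the unpadded path feeds the input tensor straight into the conv:

// Sketch of the padded preparation: zero-fill a larger tensor, then insert the
// input at the padding offset (here a symmetric padding of 1). %input is assumed
// to be the 2x6x11x13 conv input from the reproducer.
%cst = arith.constant 0.000000e+00 : f32
%empty = tensor.empty() : tensor<2x6x13x15xf32>
%zeroed = linalg.fill ins(%cst : f32) outs(%empty : tensor<2x6x13x15xf32>) -> tensor<2x6x13x15xf32>
%padded = tensor.insert_slice %input into %zeroed[0, 0, 1, 1] [2, 6, 11, 13] [1, 1, 1, 1] : tensor<2x6x11x13xf32> into tensor<2x6x13x15xf32>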

sogartar commented 3 weeks ago

If I use just a single device I get the same erroneous result. If I further remove the flow.tensor.transfer ops, the result is OK. I wanted to see what happens with the flow.tensor.transfer ops removed just before the conversion to stream, but got this segfault: #18300.

MaheshRavishankar commented 3 weeks ago

The comment says that you can repro for a single device... your program doesn't compile with a single device for me. Can you post the repro for a single device?

sogartar commented 3 weeks ago

@MaheshRavishankar, what type of compiler error did you get? Did you pick up this recent fix #18217? In order to compile for a single device, some manual modification is required, namely renaming @__device_1 -> @__device_0. Then we are still left with flow.tensor.transfer ops that cause problems.

Here are some of the program variants I tried.

I compiled the program for a single device up to, but excluding, ConvertToStreamPass (iree-stream-conversion). Then I removed the flow.tensor.transfer ops. I get the same erroneous result. This leads me to believe that either something before that point is going wrong, or the subsequent passes choke on this particular input. I have not been able to spot a problem at that stage when looking at the IR.

sogartar commented 3 weeks ago

Here is another variant. It is the original multi-device program compiled up to, but excluding, iree-flow-annotate-dispatches, with the dispatches swapped in from the single-device variant with flow.tensor.transfer ops removed, where we know things are good. This variant produces good results. The only difference between the dispatches is that the bad variant's conv dispatches have dynamic tensor shapes.

// Good
  flow.executable private @main$async_dispatch_2 {
    flow.executable.export public @main$async_dispatch_2_conv_2d_nchw_fchw_2x4x7x9x6x5x5_f32 workgroups() -> (index, index, index) {
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
      flow.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @main$async_dispatch_2_conv_2d_nchw_fchw_2x4x7x9x6x5x5_f32(%arg0: !flow.dispatch.tensor<readonly:tensor<2x6x11x13xf32>>, %arg1: !flow.dispatch.tensor<readonly:tensor<4x6x5x5xf32>>, %arg2: !flow.dispatch.tensor<writeonly:tensor<2x4x7x9xf32>>) {
        %cst = arith.constant 0.000000e+00 : f32
        %0 = flow.dispatch.tensor.load %arg0, offsets = [0, 0, 0, 0], sizes = [2, 6, 11, 13], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x6x11x13xf32>> -> tensor<2x6x11x13xf32>
        %1 = flow.dispatch.tensor.load %arg1, offsets = [0, 0, 0, 0], sizes = [4, 6, 5, 5], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4x6x5x5xf32>> -> tensor<4x6x5x5xf32>
        %2 = tensor.empty() : tensor<2x4x7x9xf32>
        %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<2x4x7x9xf32>) -> tensor<2x4x7x9xf32>
        %4 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%0, %1 : tensor<2x6x11x13xf32>, tensor<4x6x5x5xf32>) outs(%3 : tensor<2x4x7x9xf32>) -> tensor<2x4x7x9xf32>
        flow.dispatch.tensor.store %4, %arg2, offsets = [0, 0, 0, 0], sizes = [2, 4, 7, 9], strides = [1, 1, 1, 1] : tensor<2x4x7x9xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x4x7x9xf32>>
        return
      }
    }
  }
// Bad
  flow.executable private @main$async_dispatch_2 {
    flow.executable.export public @main$async_dispatch_2_conv_2d_nchw_fchw_2x4x7x9x6x5x5_f32 workgroups(%arg0: index, %arg1: index, %arg2: index, %arg3: index) -> (index, index, index) {
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg0, %arg1, %arg2, %arg3
      flow.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @main$async_dispatch_2_conv_2d_nchw_fchw_2x4x7x9x6x5x5_f32(%arg0: !flow.dispatch.tensor<readonly:tensor<?x?x?x?xf32>>, %arg1: !flow.dispatch.tensor<readonly:tensor<4x6x5x5xf32>>, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: !flow.dispatch.tensor<writeonly:tensor<2x4x7x9xf32>>) {
        %cst = arith.constant 0.000000e+00 : f32
        %0 = flow.dispatch.workload.ordinal %arg2, 0 : index
        %1 = flow.dispatch.workload.ordinal %arg3, 1 : index
        %2 = flow.dispatch.workload.ordinal %arg4, 2 : index
        %3 = flow.dispatch.workload.ordinal %arg5, 3 : index
        %4 = flow.dispatch.tie_shape %arg0 : !flow.dispatch.tensor<readonly:tensor<?x?x?x?xf32>>{%0, %1, %2, %3}
        %5 = flow.dispatch.tensor.load %4, offsets = [0, 0, 0, 0], sizes = [%0, %1, %2, %3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x?x?xf32>>{%0, %1, %2, %3} -> tensor<?x?x?x?xf32>
        %6 = flow.dispatch.tensor.load %arg1, offsets = [0, 0, 0, 0], sizes = [4, 6, 5, 5], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4x6x5x5xf32>> -> tensor<4x6x5x5xf32>
        %7 = tensor.empty() : tensor<2x4x7x9xf32>
        %cast = tensor.cast %5 : tensor<?x?x?x?xf32> to tensor<2x6x11x13xf32>
        %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<2x4x7x9xf32>) -> tensor<2x4x7x9xf32>
        %9 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%cast, %6 : tensor<2x6x11x13xf32>, tensor<4x6x5x5xf32>) outs(%8 : tensor<2x4x7x9xf32>) -> tensor<2x4x7x9xf32>
        flow.dispatch.tensor.store %9, %arg6, offsets = [0, 0, 0, 0], sizes = [2, 4, 7, 9], strides = [1, 1, 1, 1] : tensor<2x4x7x9xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x4x7x9xf32>>
        return
      }
    }
  }

This leads me to believe that the dynamic shapes, and the conversion back to static shapes inside the dispatch, are the culprit. Note that there are other dispatches that copy slices and also use dynamic shapes, but they are fine, so the problem is something more specific to this particular conv dispatch.

sogartar commented 3 weeks ago

Here is a comparison of just the dispatches, completely isolated: compare-good-and-bad-dispatches.zip. It manifests the same issue. Since this does not depend on multiple devices, I could validate whether backends other than llvm-cpu suffer from the same issue. This should help narrow down which passes may cause the problem.

sogartar commented 3 weeks ago

I tried the comparison I posted in my previous post on the HIP backend and got correct results. Maybe the issue is with the dispatch's push constants on the CPU backend.

sogartar commented 2 weeks ago

I made a modification of the faulty dispatch where the constants that are passed in are also returned: return-dispatch-consts.zip. They come back correct. This means the bug is squarely in the llvm-cpu codegen pipeline.

sogartar commented 2 weeks ago

To further verify that the dynamic tensor argument to the dispatch is the culprit, I edited the output of iree-hal-materialize-interfaces: static_hal.interface.binding.subspan.zip. I changed the tensor type to a static tensor, which resulted in a correct computation. I then tried to further isolate the problem by running a dispatch with a dynamic tensor argument that does a simple tensor copy, but that produced a correct result. This means the problem lies in the combination of the dynamic-shaped dispatch argument and the particular workload involving the linalg.conv_2d_nchw_fchw op.
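For reference, the copy-only probe was along these lines (a hypothetical sketch built from the same constructs as the reproducer, not the exact IR I ran):

// Hypothetical copy-only dispatch: load the dynamically shaped input, cast it
// to the known static shape, and store it straight to the output.
%copy = flow.dispatch.workgroups(%input, %d0, %d1, %d2, %d3) : (tensor<?x?x?x?xf32>{%d0, %d1, %d2, %d3}, index, index, index, index) -> tensor<2x6x11x13xf32> =
    (%in: !flow.dispatch.tensor<readonly:tensor<?x?x?x?xf32>>, %s0: index, %s1: index, %s2: index, %s3: index, %out: !flow.dispatch.tensor<writeonly:tensor<2x6x11x13xf32>>) {
  %t = flow.dispatch.tensor.load %in, offsets = [0, 0, 0, 0], sizes = [%s0, %s1, %s2, %s3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x?x?xf32>>{%s0, %s1, %s2, %s3} -> tensor<?x?x?x?xf32>
  %cast = tensor.cast %t : tensor<?x?x?x?xf32> to tensor<2x6x11x13xf32>
  flow.dispatch.tensor.store %cast, %out, offsets = [0, 0, 0, 0], sizes = [2, 6, 11, 13], strides = [1, 1, 1, 1] : tensor<2x6x11x13xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x6x11x13xf32>>
  flow.return
} count() -> (index, index, index) {
  %x, %y, %z = flow.dispatch.workgroup_count_from_slice
  flow.return %x, %y, %z : index, index, index
}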

MaheshRavishankar commented 2 weeks ago

@lialan could you try to take a look? I can repro locally and will try to take a look as well, but please do take a look.

MaheshRavishankar commented 2 weeks ago

Ok, I see the issue here:

module {
  func.func public @main(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<4x6x5x5xf32>) -> (tensor<4xi64>, tensor<2x4x7x9xf32>) {
    %c2 = arith.constant 2 : index
    %c11 = arith.constant 11 : index
    %c13 = arith.constant 13 : index
    %c6 = arith.constant 6 : index
    %cst = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<2x4x7x9xf32>
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4x7x9xf32>) -> tensor<2x4x7x9xf32>
    %2:2 = flow.dispatch.workgroups(%arg0, %arg1, %c2, %c6, %c11, %c13) : (tensor<?x?x?x?xf32>{%c2, %c6, %c11, %c13}, tensor<4x6x5x5xf32>, index, index, index, index) -> (tensor<4xi64>, tensor<2x4x7x9xf32>) =
        (%arg2: !flow.dispatch.tensor<readonly:tensor<?x?x?x?xf32>>, %arg3: !flow.dispatch.tensor<readonly:tensor<4x6x5x5xf32>>, %arg4: index, %arg5: index, %arg6: index, %arg7: index, %arg8: !flow.dispatch.tensor<writeonly:tensor<4xi64>>, %arg9: !flow.dispatch.tensor<writeonly:tensor<2x4x7x9xf32>>) {
      %3 = flow.dispatch.tensor.load %arg2, offsets = [0, 0, 0, 0], sizes = [%arg4, %arg5, %arg6, %arg7], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x?x?xf32>>{%arg4, %arg5, %arg6, %arg7} -> tensor<?x?x?x?xf32>
      %4 = flow.dispatch.tensor.load %arg3, offsets = [0, 0, 0, 0], sizes = [4, 6, 5, 5], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4x6x5x5xf32>> -> tensor<4x6x5x5xf32>
      %5 = tensor.empty() : tensor<2x4x7x9xf32>
      %cst_0 = arith.constant 0.000000e+00 : f32
      %cast = tensor.cast %3 : tensor<?x?x?x?xf32> to tensor<2x6x11x13xf32>
      %6 = linalg.fill ins(%cst_0 : f32) outs(%5 : tensor<2x4x7x9xf32>) -> tensor<2x4x7x9xf32>
      %7 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%cast, %4 : tensor<2x6x11x13xf32>, tensor<4x6x5x5xf32>) outs(%6 : tensor<2x4x7x9xf32>) -> tensor<2x4x7x9xf32>
      flow.dispatch.tensor.store %7, %arg9, offsets = [0, 0, 0, 0], sizes = [2, 4, 7, 9], strides = [1, 1, 1, 1] : tensor<2x4x7x9xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x4x7x9xf32>>
      %8 = tensor.empty() : tensor<4xi64>
      %c0 = arith.constant 0 : index
      %9 = arith.index_cast %arg4 : index to i64
      %inserted = tensor.insert %9 into %8[%c0] : tensor<4xi64>
      %c1 = arith.constant 1 : index
      %10 = arith.index_cast %arg5 : index to i64
      %inserted_1 = tensor.insert %10 into %inserted[%c1] : tensor<4xi64>
      %c2_2 = arith.constant 2 : index
      %11 = arith.index_cast %arg6 : index to i64
      %inserted_3 = tensor.insert %11 into %inserted_1[%c2_2] : tensor<4xi64>
      %c3 = arith.constant 3 : index
      %12 = arith.index_cast %arg7 : index to i64
      %inserted_4 = tensor.insert %12 into %inserted_3[%c3] : tensor<4xi64>
      flow.dispatch.tensor.store %inserted_4, %arg8, offsets = [0], sizes = [4], strides = [1] : tensor<4xi64> -> !flow.dispatch.tensor<writeonly:tensor<4xi64>>
      flow.return
    }
    return %2#0, %2#1 : tensor<4xi64>, tensor<2x4x7x9xf32>
  }
}

We need to make this more robust, but I think adding a flow.dispatch.workgroups op without adding a region that describes the number of workgroups is not going to work. I am not sure why this is using flow.dispatch.workgroups to begin with. I am fairly sure that if you drop the flow.dispatch.workgroups op, things will compile fine.

What is actually happening is that because the number-of-workgroups region is left empty, the default workgroup count region that gets added is just wrong... I need to look deeper into how to manage the guards for this, but the current expectation is that if you are adding a flow.dispatch.workgroups op you should also add the number-of-workgroups region.

The other alternative is to just add a flow.dispatch.region. That should also work (haven't verified it, but will do shortly).
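Something along these lines (a sketch only; the SSA names are placeholders and the exact flow.dispatch.region syntax is from memory, not verified):

// Sketch of the flow.dispatch.region alternative: wrap the fill + conv and let
// the compiler derive the workgroup count itself. %cst, %empty, %input and
// %filter are assumed to be defined in the surrounding function.
%result = flow.dispatch.region -> (tensor<2x4x7x9xf32>) {
  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<2x4x7x9xf32>) -> tensor<2x4x7x9xf32>
  %conv = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%input, %filter : tensor<2x6x11x13xf32>, tensor<4x6x5x5xf32>) outs(%fill : tensor<2x4x7x9xf32>) -> tensor<2x4x7x9xf32>
  flow.return %conv : tensor<2x4x7x9xf32>
}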

sogartar commented 2 weeks ago

I added a count region to flow.dispatch.workgroups: workgroups-count.zip.

count(%arg2: index, %arg3: index, %arg4: index, %arg5: index) -> (index, index, index) {
    %c7 = arith.constant 7 : index
    %c4 = arith.constant 4 : index
    %c1 = arith.constant 1 : index
    flow.return %c7, %c4, %c1 : index, index, index
  }

This produces the same wrong result.

In the original dispatch comparison, the good and bad variants both end up with the same workgroup counts after iree-hal-translate-target-executable-variants.

Good:

  hal.executable.export public @main_dispatch_0_conv_2d_nchw_fchw_2x4x7x9x6x5x5_f32 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]} {
  ^bb0(%arg0: !hal.device):
    %c7 = arith.constant 7 : index
    %c4 = arith.constant 4 : index
    %c1 = arith.constant 1 : index
    hal.return %c7, %c4, %c1 : index, index, index
  }

Bad:

  hal.executable.export public @main_dispatch_0_conv_2d_nchw_fchw_2x4x7x9x6x5x5_f32 ordinal(0) layout(#hal.pipeline.layout<push_constants = 8, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]} {
  ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index, %arg4: index):
    %c7 = arith.constant 7 : index
    %c4 = arith.constant 4 : index
    %c1 = arith.constant 1 : index
    hal.return %c7, %c4, %c1 : index, index, index
  }

MaheshRavishankar commented 2 weeks ago

@lialan I think I have ruled out any "structural" issues... this does seem to be a straight-up correctness bug in the CPU backend.

Just to post some notes, I tried this input

func.func public @main(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<4x6x5x5xf32>) -> (tensor<4xi64>, tensor<2x4x7x9xf32>) {
  %c2 = arith.constant 2 : index
  %c6 = arith.constant 6 : index
  %c11 = arith.constant 11 : index
  %c13 = arith.constant 13 : index
  %cst_1 = arith.constant 0.000000e+00 : f32
  %7 = tensor.empty() : tensor<2x4x7x9xf32>
  %8 = linalg.fill ins(%cst_1 : f32) outs(%7 : tensor<2x4x7x9xf32>) -> tensor<2x4x7x9xf32>
  %9, %10 = flow.dispatch.workgroups(%arg0, %arg1, %c2, %c6, %c11, %c13) :
      (tensor<?x?x?x?xf32>{%c2, %c6, %c11, %c13}, tensor<4x6x5x5xf32>, index, index, index, index) -> (tensor<4xi64>, tensor<2x4x7x9xf32>) =
      (%arg4: !flow.dispatch.tensor<readonly:tensor<?x?x?x?xf32>>, %arg5: !flow.dispatch.tensor<readonly:tensor<4x6x5x5xf32>>,
      %arg6: index, %arg7: index, %arg8: index, %arg9: index,
      %arg10: !flow.dispatch.tensor<writeonly:tensor<4xi64>>, %arg11: !flow.dispatch.tensor<writeonly:tensor<2x4x7x9xf32>>) {
    %18 = flow.dispatch.tensor.load %arg4, offsets = [0, 0, 0, 0], sizes = [%arg6, %arg7, %arg8, %arg9], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x?x?xf32>>{%arg6, %arg7, %arg8, %arg9} -> tensor<?x?x?x?xf32>
    %19 = flow.dispatch.tensor.load %arg5, offsets = [0, 0, 0, 0], sizes = [4, 6, 5, 5], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4x6x5x5xf32>> -> tensor<4x6x5x5xf32>
    %20 = tensor.empty() : tensor<2x4x7x9xf32>
    %cst_17 = arith.constant 0.000000e+00 : f32
    %cast_18 = tensor.cast %18 : tensor<?x?x?x?xf32> to tensor<2x6x11x13xf32>
    %21 = linalg.fill ins(%cst_17 : f32) outs(%20 : tensor<2x4x7x9xf32>) -> tensor<2x4x7x9xf32>
    %22 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%cast_18, %19 : tensor<2x6x11x13xf32>, tensor<4x6x5x5xf32>) outs(%21 : tensor<2x4x7x9xf32>) -> tensor<2x4x7x9xf32>
    flow.dispatch.tensor.store %22, %arg11, offsets = [0, 0, 0, 0], sizes = [2, 4, 7, 9], strides = [1, 1, 1, 1] : tensor<2x4x7x9xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x4x7x9xf32>>

    %csts_tensor = tensor.empty() : tensor<4xi64>
    %c0 = arith.constant 0 : index
    %arg6i64 = arith.index_cast %arg6 : index to i64
    %csts_tensor1 = tensor.insert %arg6i64 into %csts_tensor[%c0] : tensor<4xi64>
    %c1 = arith.constant 1 : index
    %arg7i64 = arith.index_cast %arg7 : index to i64
    %csts_tensor2 = tensor.insert %arg7i64 into %csts_tensor1[%c1] : tensor<4xi64>
    %c2 = arith.constant 2 : index
    %arg8i64 = arith.index_cast %arg8 : index to i64
    %csts_tensor3 = tensor.insert %arg8i64 into %csts_tensor2[%c2] : tensor<4xi64>
    %c3 = arith.constant 3 : index
    %arg9i64 = arith.index_cast %arg9 : index to i64
    %csts_tensor4 = tensor.insert %arg9i64 into %csts_tensor3[%c3] : tensor<4xi64>
    flow.dispatch.tensor.store %csts_tensor4, %arg10, offsets = [0], sizes = [4], strides = [1] : tensor<4xi64> -> !flow.dispatch.tensor<writeonly:tensor<4xi64>>

    flow.return
  } count() -> (index, index, index) {
    %x, %y, %z = flow.dispatch.workgroup_count_from_slice
    flow.return %x, %y, %z : index, index, index
  }
  return %9, %10 : tensor<4xi64>, tensor<2x4x7x9xf32>
}

(just a better version of what I'd expect to work while I was triaging). This still repros the error.

MaheshRavishankar commented 2 weeks ago

Ok, I think I found the issue. There seems to be something wrong with the FoldMemRefAliasOps pass. Attaching the IR for the passing case and the failing case.

bad_passes_cpu.mlir.txt

good_passes_cpu.mlir.txt

This is the IR for the PASSING case:

  %15 = vector.load %0[%3, %c0, %13, %14] : memref<2x6x11x13xf32>, vector<3xf32>
  %16 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%workgroup_id_x, %7]
  %17 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%5, %10]
  %18 = vector.load %0[%3, %c1, %16, %17] : memref<2x6x11x13xf32>, vector<3xf32>
  %19 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%workgroup_id_x, %7]
  %20 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%5, %10]
  %21 = vector.load %0[%3, %c2, %19, %20] : memref<2x6x11x13xf32>, vector<3xf32>
  %22 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%workgroup_id_x, %7]
  %23 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%5, %10]
  %24 = vector.load %0[%3, %c3, %22, %23] : memref<2x6x11x13xf32>, vector<3xf32>
  %25 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%workgroup_id_x, %7]
  %26 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%5, %10]
  %27 = vector.load %0[%3, %c4, %25, %26] : memref<2x6x11x13xf32>, vector<3xf32>
  %28 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%workgroup_id_x, %7]
  %29 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%5, %10]
  %30 = vector.load %0[%3, %c5, %28, %29] : memref<2x6x11x13xf32>, vector<3xf32>

and this is the IR for the FAILING case:

  %43 = vector.load %30[%31, %c0, %41, %42] : memref<?x?x?x?xf32>, vector<3xf32>
  %44 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 + 1)>()[%workgroup_id_x, %35]
  %45 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%33, %38]
  %46 = vector.load %30[%31, %c0, %44, %45] : memref<?x?x?x?xf32>, vector<3xf32>
  %47 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 + 2)>()[%workgroup_id_x, %35]
  %48 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%33, %38]
  %49 = vector.load %30[%31, %c0, %47, %48] : memref<?x?x?x?xf32>, vector<3xf32>
  %50 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 + 3)>()[%workgroup_id_x, %35]
  %51 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%33, %38]
  %52 = vector.load %30[%31, %c0, %50, %51] : memref<?x?x?x?xf32>, vector<3xf32>
  %53 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 + 4)>()[%workgroup_id_x, %35]
  %54 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%33, %38]
  %55 = vector.load %30[%31, %c0, %53, %54] : memref<?x?x?x?xf32>, vector<3xf32>
  %56 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 + 5)>()[%workgroup_id_x, %35]
  %57 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%33, %38]
  %58 = vector.load %30[%31, %c0, %56, %57] : memref<?x?x?x?xf32>, vector<3xf32>

The indexing seems off... the passing case increments the index of dimension 1 by 1 every time, but the failing case does not do that; instead it adds the increment to dimension 2.
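Spelled out for the second load (reading off the two snippets above; illustrative, not actual pass output):

// Expected, by analogy with the passing IR: the channel index advances to %c1
// while the H index stays plain s0 + s1.
%46 = vector.load %30[%31, %c1, %44, %45] : memref<?x?x?x?xf32>, vector<3xf32>  // with %44 = s0 + s1
// Actual failing IR: the channel index stays %c0 and the +1 leaks into the H index.
%46 = vector.load %30[%31, %c0, %44, %45] : memref<?x?x?x?xf32>, vector<3xf32>  // with %44 = s0 + s1 + 1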

@lialan this should be enough triage for you to fix?

lialan commented 1 week ago

Tracking this issue: it seems that the discrepancy in the dimension happens when trying to fold a memref.subview into a vector.load. In the bad path with a dynamic shape, the subview is:

`%subview_7 = memref.subview %subview_5[0, 0, 0, 0] [1, 6, 1, 3] [1, 1, 1, 1] : memref<1x6x1x3xf32, strided<[?, ?, ?, 1], offset: ?>> to memref<1x6x3xf32, strided<[?, ?, 1], offset: ?>>`

When getDroppedDims is called on it, it returns [true, false, false, false], while in the good path with a static shape, the subview is:

`%subview_6 = memref.subview %subview_4[0, 0, 0, 0] [1, 6, 1, 3] [1, 1, 1, 1] : memref<1x6x1x3xf32, strided<[858, 143, 13, 1], offset: ?>> to memref<1x6x3xf32, strided<[858, 143, 1], offset: ?>>`

and getDroppedDims returns [false, false, true, false].

lialan commented 1 week ago

computeMemRefRankReductionMask tries to figure out which dimension is dropped by looking at the strides.

It does feel like, by looking solely at the strides, we reach a reasonable-looking result for the dynamic-shaped input, but in reality we need more information to decide which dim is dropped, as both dim 0 and dim 2 could possibly be the dropped one in:

memref<1x6x1x3xf32, strided<[?, ?, ?, 1], offset: ?>> to
memref<1x6x3xf32, strided<[?, ?, 1], offset: ?>>
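
To spell out the ambiguity with the two subviews quoted above (same ops, comments added for clarity):

// Static strides (good path): the result strides [858, 143, 1] are only
// consistent with dropping the unit dim whose stride is 13, i.e. dim 2.
%subview_6 = memref.subview %subview_4[0, 0, 0, 0] [1, 6, 1, 3] [1, 1, 1, 1] : memref<1x6x1x3xf32, strided<[858, 143, 13, 1], offset: ?>> to memref<1x6x3xf32, strided<[858, 143, 1], offset: ?>>
// Dynamic strides (bad path): [?, ?, ?, 1] -> [?, ?, 1] holds whether dim 0 or
// dim 2 is dropped, so the strides alone cannot identify the dropped dim.
%subview_7 = memref.subview %subview_5[0, 0, 0, 0] [1, 6, 1, 3] [1, 1, 1, 1] : memref<1x6x1x3xf32, strided<[?, ?, ?, 1], offset: ?>> to memref<1x6x3xf32, strided<[?, ?, 1], offset: ?>>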

@MaheshRavishankar in such a case, where a single unambiguous result cannot be reached, should we just bail out? Or maybe you have better ideas...

lialan commented 6 days ago

Edit: I was wrong in my previous comment.

The dynamic dims are lowered correctly after materialization. The issue is with the memref.subview.