iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Number of dims and results of reindexed AffineMap doesn't match on Vectorization #17591

Open jinchen62 opened 1 month ago

jinchen62 commented 1 month ago

What happened?

dispatch: https://gist.github.com/jinchen62/5e2af98f9b5bfc3b55e949f964459815
error log: https://gist.github.com/jinchen62/df2038b5a43ed4680804a3d7d0647d95

The failing op dumped at https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp#L336 is

%10 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (0, d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor<1x4xf32>) outs(%arg2 : tensor<1x1xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 0], [1, 0], [0, 4], [0, 0]]>} {
^bb0(%in: f32, %out: f32):
  %11 = arith.addf %in, %out : f32
  linalg.yield %11 : f32
} -> tensor<1x1xf32>

At the failing assertion https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp#L474, the map is reindexed from (d0, d1) -> (0, d0) to (d0) -> (0, d0), so the number of dims (1) no longer matches the number of results (2).
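To make the assertion concrete, here is a toy Python model of the reindexing step. It assumes the step behaves like MLIR's compressUnusedDims (drop loop dims that no result uses, then renumber the rest) followed by a dims-versus-results check; the classes and the reindex helper are hypothetical stand-ins, not MLIR's C++ API.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Dim:
    pos: int  # loop dimension d<pos>


@dataclass(frozen=True)
class Const:
    value: int  # constant result, e.g. the 0 in (d0, d1) -> (0, d0)


Expr = Union[Dim, Const]


@dataclass(frozen=True)
class AffineMap:
    num_dims: int
    results: Tuple[Expr, ...]


def compress_unused_dims(m: AffineMap) -> AffineMap:
    """Drop loop dims that no result refers to and renumber the rest."""
    used = sorted({e.pos for e in m.results if isinstance(e, Dim)})
    renumber = {old: new for new, old in enumerate(used)}
    results = tuple(Dim(renumber[e.pos]) if isinstance(e, Dim) else e
                    for e in m.results)
    return AffineMap(len(used), results)


def reindex(m: AffineMap) -> AffineMap:
    """Hypothetical model of the check that fires during vectorization."""
    res = compress_unused_dims(m)
    if res.num_dims != len(res.results):
        raise AssertionError(
            f"{res.num_dims} dims vs {len(res.results)} results")
    return res


# Folded output map (d0, d1) -> (d0): reindexes cleanly to (d0) -> (d0).
print(reindex(AffineMap(2, (Dim(0),))))

# Unfolded output map (d0, d1) -> (0, d0): becomes (d0) -> (0, d0) and fails.
try:
    reindex(AffineMap(2, (Const(0), Dim(0))))
except AssertionError as err:
    print("assertion failed:", err)
```

In this model the constant 0 occupies a result slot without keeping a loop dim alive, so only the folded form (d0, d1) -> (d0) survives the check.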

Steps to reproduce your issue

Run iree-compile --iree-input-demote-i64-to-i32 --iree-hal-target-backends=llvm-cpu dispatch_1.mlir -o test.vmfb 2> dump.mlir with top-of-main (TOM) IREE.

What component(s) does this issue relate to?

No response

Version information

No response

Additional context

No response

hanhanW commented 1 month ago

Inlining the MLIR input below. At first, I thought that the (d0, d1) -> (0, d0) map was generated during codegen, but that is not the case: the (d0, d1) -> (0, d0) affine_map is already present in codegen's input. @jinchen62 do you know how the input is generated? It would be very helpful if you could track it back to a small set of linalg ops or tosa/torch ops. The 0 should be folded away by FoldUnitExtentDimsPass at the global-opt or flow level, i.e., the map should be (d0, d1) -> (d0) by the time it reaches codegen.

hal.executable public @main_graph$async_dispatch_1 {
  hal.executable.variant public @embedded_elf_x86_64 target(<"llvm-cpu", "embedded-elf-x86_64", {cpu = "generic", cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 16 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf"}>) {
    hal.executable.export public @main_graph$async_dispatch_1_generic_9x1024_f32 ordinal(0) layout(#hal.pipeline.layout<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>]} {
    ^bb0(%arg0: !hal.device):
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
      hal.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @main_graph$async_dispatch_1_generic_9x1024_f32() {
        %cst = arith.constant 0.000000e+00 : f32
        %0 = hal.interface.constant.load[0] : i32
        %1 = hal.interface.constant.load[1] : i32
        %2 = arith.index_castui %0 : i32 to index
        %3 = arith.index_castui %1 : i32 to index
        %4 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%2) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9x1024xf32>>
        %5 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%3) : !flow.dispatch.tensor<writeonly:tensor<1x9xf32>>
        %6 = flow.dispatch.tensor.load %4, offsets = [0, 0], sizes = [9, 1024], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9x1024xf32>> -> tensor<9x1024xf32>
        %7 = tensor.empty() : tensor<1x9xf32>
        %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x9xf32>) -> tensor<1x9xf32>
        %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (0, d0)>], iterator_types = ["parallel", "reduction"]} ins(%6 : tensor<9x1024xf32>) outs(%8 : tensor<1x9xf32>) {
        ^bb0(%in: f32, %out: f32):
          %10 = arith.addf %in, %out : f32
          linalg.yield %10 : f32
        } -> tensor<1x9xf32>
        flow.dispatch.tensor.store %9, %5, offsets = [0, 0], sizes = [1, 9], strides = [1, 1] : tensor<1x9xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x9xf32>>
        return
      }
    }
  }
}
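The fold hanhanW expects can be illustrated with a toy Python model: an operand dim that is statically 1 and indexed by the constant 0 is dropped, which turns the (0, d0) output map into (d0). The helper drop_unit_const_results is hypothetical and only mimics the intended effect; it is not IREE's FoldUnitExtentDimsPass implementation.

```python
from typing import List, Optional, Tuple

# Shape entries: int for static sizes, None for a dynamic "?" size.
Shape = Tuple[Optional[int], ...]
# Map results as strings: "d0", "d1", ... for loop dims, "0" for the constant zero.
Results = Tuple[str, ...]


def drop_unit_const_results(results: Results, shape: Shape) -> Tuple[Results, Shape]:
    """Drop operand dims that are statically 1 and indexed by the constant 0."""
    kept_results: List[str] = []
    kept_shape: List[Optional[int]] = []
    for expr, size in zip(results, shape):
        if expr == "0" and size == 1:
            continue  # unit dim addressed by a constant: fold it away
        kept_results.append(expr)
        kept_shape.append(size)
    return tuple(kept_results), tuple(kept_shape)


# Output operand of the dispatch above: tensor<1x9xf32> with (d0, d1) -> (0, d0).
print(drop_unit_const_results(("0", "d0"), (1, 9)))  # -> (('d0',), (9,))

# In this toy model a dynamic leading size is left alone.
print(drop_unit_const_results(("0", "d0"), (None, 9)))  # -> (('0', 'd0'), (None, 9))
```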
hanhanW commented 1 month ago

I worked with @jinchen62 and we got a smaller repro: https://gist.github.com/hanhanW/b3652f5887b93fb8f0df6c6c39c1ef87

To repro, run iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-fold-unit-extent-dims))" ~/repro.mlir.

Then you'll see affine_map<(d0, d1) -> (0, d0)> in the result.

#map2 = affine_map<(d0, d1) -> (d0, d1)>
#map8 = affine_map<(d0, d1) -> (0, d0)>
// ...
    %29 = linalg.generic {indexing_maps = [#map2, #map8], iterator_types = ["parallel", "reduction"]} ins(%collapsed_12 : tensor<9x1024xf32>) outs(%28 : tensor<?x9xf32>) {
    ^bb0(%in: f32, %out: f32):
      %35 = arith.addf %in, %out : f32
      linalg.yield %35 : f32
    } -> tensor<?x9xf32>
// ...
hanhanW commented 1 month ago

Actually, the input reduction op looks weird. The sizes of d0 mismatch: one is 1 and the other is ?. It looks like there is a bug in the frontend lowering. @jinchen62 you can add -mlir-print-debuginfo to iree-compile, and it will tell you where the op is lowered from. My guess is that there is a bug in the XXX->Linalg lowering.

#map5 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map10 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
    %25 = tensor.empty(%12) : tensor<?x9x1xf32>
    %26 = linalg.fill ins(%cst_7 : f32) outs(%25 : tensor<?x9x1xf32>) -> tensor<?x9x1xf32>
    %27 = linalg.generic {indexing_maps = [#map5, #map10], iterator_types = ["parallel", "parallel", "reduction"]} ins(%24 : tensor<1x9x1024xf32>) outs(%26 : tensor<?x9x1xf32>) {
    ^bb0(%in: f32, %out: f32):
      %31 = arith.addf %in, %out : f32
      linalg.yield %31 : f32
    } -> tensor<?x9x1xf32>
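The d0 mismatch can be spotted mechanically by collecting, for each loop dim, the size that every operand implies through its indexing map. A minimal Python sketch, assuming shapes and maps are written as plain tuples; loop_dim_sizes and inconsistent_dims are hypothetical helpers, not the Linalg verifier.

```python
from typing import Dict, List, Optional, Sequence, Tuple

Shape = Tuple[Optional[int], ...]  # None stands for a dynamic "?" size
Results = Tuple[str, ...]          # "d0", "d1", ... or "0" for a constant


def loop_dim_sizes(operands: Sequence[Tuple[Results, Shape]]) -> Dict[str, List[Optional[int]]]:
    """Collect, for each loop dim, the size every operand implies for it."""
    sizes: Dict[str, List[Optional[int]]] = {}
    for results, shape in operands:
        for expr, size in zip(results, shape):
            if expr.startswith("d"):
                sizes.setdefault(expr, []).append(size)
    return sizes


def inconsistent_dims(operands: Sequence[Tuple[Results, Shape]]) -> List[str]:
    """Loop dims whose implied sizes are not all identical (static or dynamic)."""
    return sorted(d for d, s in loop_dim_sizes(operands).items() if len(set(s)) > 1)


generic_27 = [
    (("d0", "d1", "d2"), (1, 9, 1024)),  # ins:  tensor<1x9x1024xf32>
    (("d0", "d1", "0"), (None, 9, 1)),   # outs: tensor<?x9x1xf32>
]
print(inconsistent_dims(generic_27))  # -> ['d0']
```

Here d0 is 1 for the input but ? for the output, matching the mismatch described above.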
jinchen62 commented 4 weeks ago

smaller repro: https://gist.github.com/jinchen62/91e216fb39abbb9ba4c0461346d2bb5a

command: iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-fold-unit-extent-dims))" repro.mlir or iree-compile --iree-hal-target-backends=llvm-cpu repro.mlir -o test.vmfb --mlir-print-ir-after-all 2> dump.mlir

hanhanW commented 4 weeks ago

@jinchen62 did you get a chance to see which op is generating the IR? The generic op looks invalid to me, like I explained in the above comment.

jinchen62 commented 4 weeks ago

I think it's

%237 = torch.aten.sum.dim_IntList %235, %236, %true, %none : !torch.vtensor<[?,9,1024],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[?,9,1],f32>
hanhanW commented 4 weeks ago

I'd suggest checking whether there are bugs in the torch -> linalg lowering, or in the lowering from other high-level dialects to torch.

jinchen62 commented 4 weeks ago

torch level repro: https://gist.github.com/jinchen62/601cfce290b81e037383fc49b604a68a

iree-compile --iree-input-demote-i64-to-i32 --iree-hal-target-backends=llvm-cpu --iree-util-zero-fill-elided-attrs repro_torch.mlir -o test.vmfb

jinchen62 commented 3 weeks ago

Part of the torch repro dump, from After ExpandOps (memref-expand) to After Canonicalizer (canonicalize): https://gist.github.com/jinchen62/ae856e42b0660d0b41426e910039fb9a

@hanhanW I think that with a tensor.cast op, the reduction op you found weird should compile fine, as at line381. But after the Canonicalizer pass, the cast appears to be missing, as at line817. The following repro compiles; without the cast op at the end, it fails with the same error we are facing. Does that make sense?

#map = affine_map<(d0) -> (d0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
module {
  func.func @repro2(%arg0: tensor<1x9x1024xf32>) -> tensor<1x9x1xf32> {
    %cst = arith.constant dense<[false, true]> : tensor<2xi1>
    %cst_0 = arith.constant dense<1> : tensor<2xi32>
    %cst_1 = arith.constant dense<[1, -1]> : tensor<2xi32>
    %cst_2 = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<2xi32>
    %1 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%cst, %cst_0, %cst_1 : tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) outs(%0 : tensor<2xi32>) {
    ^bb0(%in: i1, %in_3: i32, %in_4: i32, %out: i32):
      %6 = arith.select %in, %in_3, %in_4 : i32
      linalg.yield %6 : i32
    } -> tensor<2xi32>
    %extracted_slice = tensor.extract_slice %1[0] [1] [1] : tensor<2xi32> to tensor<1xi32>
    %collapsed = tensor.collapse_shape %extracted_slice [] : tensor<1xi32> into tensor<i32>
    %extracted = tensor.extract %collapsed[] : tensor<i32>
    %2 = arith.index_cast %extracted : i32 to index
    %3 = tensor.empty(%2) : tensor<?x9x1xf32>
    %4 = linalg.fill ins(%cst_2 : f32) outs(%3 : tensor<?x9x1xf32>) -> tensor<?x9x1xf32>
    %5 = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<1x9x1024xf32>) outs(%4 : tensor<?x9x1xf32>) {
    ^bb0(%in: f32, %out: f32):
      %6 = arith.addf %in, %out : f32
      linalg.yield %6 : f32
    } -> tensor<?x9x1xf32>
    %cast = tensor.cast %5 : tensor<?x9x1xf32> to tensor<1x9x1xf32>
    return %cast : tensor<1x9x1xf32>
  }
}
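Tracing the constants in this repro by hand suggests that the ? dimension of %3 is always 1 at runtime, so the trailing tensor.cast to tensor<1x9x1xf32> agrees with the computed value. A small Python sketch of that trace; the select and dynamic_leading_dim helpers are hypothetical stand-ins for the MLIR ops named in the comments.

```python
from typing import List


def select(cond: List[bool], true_vals: List[int], false_vals: List[int]) -> List[int]:
    """Elementwise arith.select, as performed by the first linalg.generic (%1)."""
    return [t if c else f for c, t, f in zip(cond, true_vals, false_vals)]


def dynamic_leading_dim() -> int:
    cst = [False, True]  # %cst   : tensor<2xi1>
    cst_0 = [1, 1]       # %cst_0 : dense<1> : tensor<2xi32>
    cst_1 = [1, -1]      # %cst_1 : tensor<2xi32>
    v1 = select(cst, cst_0, cst_1)  # %1
    # extract_slice [0] [1] [1], collapse_shape, and extract yield element 0.
    return v1[0]  # %2 = arith.index_cast, passed to tensor.empty(%2)


print(dynamic_leading_dim())  # -> 1
```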
hanhanW commented 3 weeks ago

I'm not convinced that the issue is tensor.cast. There are some shape inference passes/patterns in MLIR dialects, and they create tensor.cast ops to spell out static shapes. With that hint, the compiler is smart enough to fold the shape information into the linalg op, which seems reasonable to me. These patterns and passes work at the Linalg level; what I can think of is that the frontend is generating invalid ops.

I don't know why we're still triaging the issue at the model level; perhaps I did not make it clear. Let me put it this way: instead of compiling the whole model, are you able to compile a single %237 = torch.aten.sum.dim_IntList %235, %236, %true, %none : !torch.vtensor<[?,9,1024],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[?,9,1],f32> op?

jinchen62 commented 3 weeks ago

I don't think it's a lowering issue. The torch.aten.sum.dim_IntList op compiles, and I traced back through the onnx->torch lowering and didn't find a bug in the lowering of any onnx op.

@raikonenfnu and I think there might be an optimization bug in the canonicalize pass that runs after memref-expand. We saw the generic op with the reduction dim change from ins(%146 : tensor<?x9x1024xf32>) outs(%149 : tensor<?x?x1xf32>) to ins(%55 : tensor<1x9x1024xf32>) outs(%57 : tensor<?x9x1xf32>) when the tensor.cast op was folded. We would expect it to change to either ins(%55 : tensor<?x9x1024xf32>) outs(%57 : tensor<?x9x1xf32>) or ins(%55 : tensor<1x9x1024xf32>) outs(%57 : tensor<1x9x1xf32>). The dump IR is here.
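The second expected outcome (both operands static) amounts to refining every ? size that shares a loop dim with a static size on another operand. A toy Python sketch of that consistency rule, assuming shapes and maps are given as plain tuples; refine_consistently is a hypothetical helper, not an existing MLIR or IREE pattern.

```python
from typing import Dict, List, Optional, Sequence, Tuple

Shape = Tuple[Optional[int], ...]  # None stands for a dynamic "?" size
Results = Tuple[str, ...]          # "d0", "d1", ... or "0" for a constant


def refine_consistently(operands: Sequence[Tuple[Results, Shape]]) -> List[Shape]:
    """Replace "?" sizes with a static size another operand implies for the same loop dim.

    Raises ValueError if two operands imply different static sizes for one loop dim.
    """
    known: Dict[str, int] = {}
    for results, shape in operands:
        for expr, size in zip(results, shape):
            if expr.startswith("d") and size is not None:
                if known.setdefault(expr, size) != size:
                    raise ValueError(f"conflicting static sizes for {expr}")
    refined: List[Shape] = []
    for results, shape in operands:
        refined.append(tuple(
            known.get(expr, size) if expr.startswith("d") and size is None else size
            for expr, size in zip(results, shape)))
    return refined


after_canonicalize = [
    (("d0", "d1", "d2"), (1, 9, 1024)),  # ins:  tensor<1x9x1024xf32>
    (("d0", "d1", "0"), (None, 9, 1)),   # outs: tensor<?x9x1xf32>
]
print(refine_consistently(after_canonicalize))  # -> [(1, 9, 1024), (1, 9, 1)]
```

In real IR, refining the outs type would also require updating or casting its producer (the linalg.fill of a tensor.empty), which this model ignores.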

AmosLewis commented 2 weeks ago

@jinchen62 So what's the plan to fix this issue? The bart-large model needs this anyway.