
Different Results for CPU and CUDA backend #14934

Closed vivekkhandelwal1 closed 1 year ago

vivekkhandelwal1 commented 1 year ago

What happened?

I'm getting different results from the PyTorch and IREE paths for the NLLLoss lowering. The result from the CPU backend is incorrect, while the result from the CUDA backend is correct.

Steps to reproduce your issue

To reproduce the issue, compile and run the following IR:

#map = affine_map<(d0) -> (d0)>
#map1 = affine_map<(d0) -> ()>
module attributes {torch.debug_module_name = "train_func"} {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @forward(%arg0: tensor<4096x8192xf32>, %arg1: tensor<4096xi64>) -> tensor<f32> {
    %cst = arith.constant 0.000000e+00 : f32
    %c-100 = arith.constant -100 : index
    %cst_0 = arith.constant 4.096000e+03 : f32
    %0 = tensor.empty() : tensor<4096xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg1 : tensor<4096xi64>) outs(%0 : tensor<4096xf32>) {
    ^bb0(%in: i64, %out: f32):
      %5 = arith.index_cast %in : i64 to index
      %6 = arith.cmpi eq, %5, %c-100 : index
      %7 = linalg.index 0 : index
      %extracted = tensor.extract %arg0[%7, %5] : tensor<4096x8192xf32>
      %8 = arith.negf %extracted : f32
      %9 = arith.select %6, %cst, %8 : f32
      linalg.yield %9 : f32
    } -> tensor<4096xf32>
    %2 = tensor.empty() : tensor<f32>
    %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<f32>) -> tensor<f32>
    %4 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction"]} ins(%1 : tensor<4096xf32>) outs(%3 : tensor<f32>) {
    ^bb0(%in: f32, %out: f32):
      %5 = arith.divf %in, %cst_0 : f32
      %6 = arith.addf %5, %out : f32
      linalg.yield %6 : f32
    } -> tensor<f32>
    return %4 : tensor<f32>
  }
}
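
For reference, here is a minimal NumPy sketch of the scalar semantics the IR encodes (synthetic stand-in inputs, since the attached .npy files are not inlined here): per row i, take 0 if target[i] == -100 and -input[i, target[i]] otherwise, divide each term by 4096, and sum.

import numpy as np

rng = np.random.default_rng(0)                                   # synthetic inputs
log_probs = rng.standard_normal((4096, 8192)).astype(np.float32)
targets = rng.integers(0, 8192, size=4096).astype(np.int64)
targets[::7] = -100                                              # some ignored entries

safe = np.where(targets == -100, 0, targets)    # avoid gathering at index -100
picked = np.where(targets == -100, np.float32(0.0),
                  -log_probs[np.arange(4096), safe])
print(np.sum(picked / np.float32(4096.0), dtype=np.float32))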

For CPU (producing incorrect results):

1.) Compilation Command:

iree-compile --iree-input-type=none --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-vm-bytecode-module-strip-source-map=false --iree-util-zero-fill-elided-attrs --iree-vm-target-truncate-unsupported-floats --iree-codegen-check-ir-before-llvm-conversion=false --iree-consteval-jit-debug=true --iree-stream-resource-max-allocation-size=3221225472 nll_loss.mlir -o nll_loss.vmfb

2.) Runtime Command:

iree-run-module --module=nll_loss.vmfb --device=local-task --function=forward --input=@log_softmax.npy --input=@view_12.npy

Please download the files log_softmax.npy and view_12.npy.

For CUDA (producing correct results):

1.) Compilation Command:

iree-compile --iree-input-type=none --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=cuda --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-vm-bytecode-module-strip-source-map=false --iree-util-zero-fill-elided-attrs --iree-vm-target-truncate-unsupported-floats --iree-codegen-check-ir-before-llvm-conversion=false --iree-opt-const-expr-hoisting=false --iree-consteval-jit-debug=true --iree-stream-resource-max-allocation-size=3221225472 nll_loss.mlir -o nll_loss.vmfb

2.) Runtime Command:

iree-run-module --module=nll_loss.vmfb --device=cuda --function=forward --input=@log_softmax.npy --input=@view_12.npy

What component(s) does this issue relate to?

MLIR, Compiler, Runtime

Version information

de4f4e539a1e1dd7f0ffe2ec2eee32b4587faa12

Additional context

No response

MaheshRavishankar commented 1 year ago

Could you give some idea of whether this is within what reassociative reordering of the reduction could produce? Could you try adding --iree-llvmcpu-reassociate-fp-reductions and see if that helps?

IanNod commented 1 year ago

@MaheshRavishankar I added that flag and see no difference in the iree-cpu result.

aviator19941 commented 1 year ago

@MaheshRavishankar Golden value: 5.9949
Incorrect/current value: 12.0519

MaheshRavishankar commented 1 year ago

That's a big difference... shucks, these are hard...

benvanik commented 1 year ago

Driveby: can we try Vulkan to ensure that it's not CUDA-vs-the-world? (I've seen cases in the past where entire models were trained and deployed with backend-specific bugs and relied on the consistency of the issue.)

benvanik commented 1 year ago

(just a thought: this could be a good case for finally wiring up hal.instrument.value - that'd give printf-debugging for any primitive SSA value, and we could do it manually by inserting the op at the vector dialect level or something - the CUDA side needs some copy-paste from the CPU side and other goo, though, so nothing immediate)

raikonenfnu commented 1 year ago

Driveby: can we try Vulkan to ensure that it's not CUDA-vs-the-world? (I've seen cases in the past where entire models were trained and deployed with backend-specific bugs and relied on the consistency of the issue.)

Perhaps, but the golden values are computed from good ol' eager-mode CPU PyTorch. The odd thing is that on one of our custom backends for custom HW, we are generating similar results to IREE-CPU. 🤔

IanNod commented 1 year ago

Still digging into it more, but I commented out the createFusionOfTensorOpsPass (iree-flow-fusion-of-tensor-ops) pass and am now getting the correct value with iree-cpu: f32=5.99486

MaheshRavishankar commented 1 year ago

That's really useful info! Let me take a look at that. (But that pass is run on both CUDA and CPU, so that is a bit strange.)

MaheshRavishankar commented 1 year ago

Turns out there was a bug that meant the flag --iree-llvmcpu-reassociate-fp-reductions=false was not actually respected. #14965 fixes that bug. With that PR, using --iree-llvmcpu-reassociate-fp-reductions=false now gives the right result.

It seems like this is falling into a category where reordering is causing an issue. We might be in numerically unstable territory here. The associative reordering I see is as follows:

// -----// IR Dump Before LLVMCPUSplitReduction (iree-llvmcpu-split-reduction) //----- //
func.func @forward_dispatch_0_generic_4096_i64xf32() {
  %c0 = arith.constant 0 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %c-100 = arith.constant -100 : index loc(callsite("repro.mlir":7:14 at "repro.mlir":5:3))
  %cst = arith.constant 4.096000e+03 : f32 loc(callsite("repro.mlir":8:14 at "repro.mlir":5:3))
  %cst_0 = arith.constant 0.000000e+00 : f32 loc(callsite("repro.mlir":6:12 at "repro.mlir":5:3))
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x8192xf32>> loc("repro.mlir":5:3)
  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096xi64>> loc("repro.mlir":5:3)
  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<f32>> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %3 = flow.dispatch.tensor.load %2, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<writeonly:tensor<f32>> -> tensor<f32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %4 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [4096, 8192], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x8192xf32>> -> tensor<4096x8192xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %5 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [4096], strides = [1] : !flow.dispatch.tensor<readonly:tensor<4096xi64>> -> tensor<4096xi64> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %6 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[], [], [], []]>} ins(%cst_0 : f32) outs(%3 : tensor<f32>) -> tensor<f32> loc(callsite("repro.mlir":21:10 at "repro.mlir":5:3))
  %7 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], iterator_types = ["reduction"]} ins(%5 : tensor<4096xi64>) outs(%6 : tensor<f32>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0], [0], [4], [0]]>} {
  ^bb0(%in: i64 loc("repro.mlir":11:10), %out: f32 loc("repro.mlir":23:20)):
    %8 = linalg.index 0 : index loc(callsite("repro.mlir":10:10 at "repro.mlir":5:3))
    %9 = arith.index_cast %in : i64 to index loc(callsite("repro.mlir":12:12 at "repro.mlir":5:3))
    %10 = arith.cmpi eq, %9, %c-100 : index loc(callsite("repro.mlir":13:12 at "repro.mlir":5:3))
    %extracted = tensor.extract %4[%8, %9] : tensor<4096x8192xf32> loc(callsite("repro.mlir":15:20 at "repro.mlir":5:3))
    %11 = arith.negf %extracted : f32 loc(callsite("repro.mlir":16:12 at "repro.mlir":5:3))
    %12 = arith.select %10, %cst_0, %11 : f32 loc(callsite("repro.mlir":17:12 at "repro.mlir":5:3))
    %13 = arith.divf %12, %cst : f32 loc(callsite("repro.mlir":24:12 at "repro.mlir":5:3))
    %14 = arith.addf %13, %out : f32 loc(callsite("repro.mlir":25:12 at "repro.mlir":5:3))
    linalg.yield %14 : f32 loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  } -> tensor<f32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  flow.dispatch.tensor.store %7, %2, offsets = [], sizes = [], strides = [] : tensor<f32> -> !flow.dispatch.tensor<writeonly:tensor<f32>> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  return loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
} loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))

// -----// IR Dump After LLVMCPUSplitReduction (iree-llvmcpu-split-reduction) //----- //
func.func @forward_dispatch_0_generic_4096_i64xf32() {
  %c1 = arith.constant 1 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %c0 = arith.constant 0 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %c0_0 = arith.constant 0 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %c0_1 = arith.constant 0 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %c-100 = arith.constant -100 : index loc(callsite("repro.mlir":7:14 at "repro.mlir":5:3))
  %cst = arith.constant 4.096000e+03 : f32 loc(callsite("repro.mlir":8:14 at "repro.mlir":5:3))
  %cst_2 = arith.constant 0.000000e+00 : f32 loc(callsite("repro.mlir":6:12 at "repro.mlir":5:3))
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0_1) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x8192xf32>> loc("repro.mlir":5:3)
  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0_1) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096xi64>> loc("repro.mlir":5:3)
  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0_1) : !flow.dispatch.tensor<writeonly:tensor<f32>> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %3 = flow.dispatch.tensor.load %2, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<writeonly:tensor<f32>> -> tensor<f32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %4 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [4096, 8192], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x8192xf32>> -> tensor<4096x8192xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %5 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [4096], strides = [1] : !flow.dispatch.tensor<readonly:tensor<4096xi64>> -> tensor<4096xi64> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %6 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[], [], [], []]>} ins(%cst_2 : f32) outs(%3 : tensor<f32>) -> tensor<f32> loc(callsite("repro.mlir":21:10 at "repro.mlir":5:3))
  %c0_3 = arith.constant 0 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %c4096 = arith.constant 4096 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %extracted_slice = tensor.extract_slice %5[0] [4096] [1] : tensor<4096xi64> to tensor<4096xi64> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %extracted_slice_4 = tensor.extract_slice %6[] [] [] : tensor<f32> to tensor<f32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %expanded = tensor.expand_shape %extracted_slice [[0, 1]] : tensor<4096xi64> into tensor<1024x4xi64> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %7 = tensor.empty() : tensor<4xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %cst_5 = arith.constant 0.000000e+00 : f32 loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %8 = linalg.fill ins(%cst_5 : f32) outs(%7 : tensor<4xf32>) -> tensor<4xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %c0_6 = arith.constant 0 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %c1024 = arith.constant 1024 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %9 = scf.for %arg0 = %c0_6 to %c1024 step %c1 iter_args(%arg1 = %8) -> (tensor<4xf32>) {
    %c0_7 = arith.constant 0 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
    %c4 = arith.constant 4 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
    %extracted_slice_8 = tensor.extract_slice %expanded[%arg0, 0] [1, 4] [1, 1] : tensor<1024x4xi64> to tensor<1x4xi64> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
    %extracted_slice_9 = tensor.extract_slice %arg1[0] [4] [1] : tensor<4xf32> to tensor<4xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
    %11 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>], iterator_types = ["reduction", "parallel"]} ins(%extracted_slice_8 : tensor<1x4xi64>) outs(%extracted_slice_9 : tensor<4xf32>) {
    ^bb0(%in: i64 loc("repro.mlir":11:10), %out: f32 loc("repro.mlir":23:20)):
      %12 = linalg.index 0 : index loc(callsite("repro.mlir":10:10 at "repro.mlir":5:3))
      %13 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%12, %arg0) loc(callsite("repro.mlir":10:10 at "repro.mlir":5:3))
      %14 = arith.index_cast %in : i64 to index loc(callsite("repro.mlir":12:12 at "repro.mlir":5:3))
      %15 = arith.cmpi eq, %14, %c-100 : index loc(callsite("repro.mlir":13:12 at "repro.mlir":5:3))
      %extracted = tensor.extract %4[%13, %14] : tensor<4096x8192xf32> loc(callsite("repro.mlir":15:20 at "repro.mlir":5:3))
      %16 = arith.negf %extracted : f32 loc(callsite("repro.mlir":16:12 at "repro.mlir":5:3))
      %17 = arith.select %15, %cst_2, %16 : f32 loc(callsite("repro.mlir":17:12 at "repro.mlir":5:3))
      %18 = arith.divf %17, %cst : f32 loc(callsite("repro.mlir":24:12 at "repro.mlir":5:3))
      %19 = arith.addf %18, %out : f32 loc(callsite("repro.mlir":25:12 at "repro.mlir":5:3))
      linalg.yield %19 : f32 loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
    } -> tensor<4xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
    %inserted_slice = tensor.insert_slice %11 into %arg1[0] [4] [1] : tensor<4xf32> into tensor<4xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
    scf.yield %inserted_slice : tensor<4xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  } loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %10 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], iterator_types = ["reduction"]} ins(%9 : tensor<4xf32>) outs(%extracted_slice_4 : tensor<f32>) {
  ^bb0(%in: f32 loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3)), %out: f32 loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))):
    %11 = arith.addf %in, %out : f32 loc(callsite("repro.mlir":25:12 at "repro.mlir":5:3))
    linalg.yield %11 : f32 loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  } -> tensor<f32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  flow.dispatch.tensor.store %10, %2, offsets = [], sizes = [], strides = [] : tensor<f32> -> !flow.dispatch.tensor<writeonly:tensor<f32>> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  return loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
} loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))

AFAICS this is just doing this reordering: instead of a + b + c + d + e + f + g + h, it is doing (a + e) + (b + f) + (c + g) + (d + h), which seems OK I think... but people who understand the numerics better should weigh in (maybe @stellaraccident or @bjacob).
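
For intuition, here is a small NumPy sketch of exactly that 4-lane reordering on synthetic data (not the repro inputs); for well-scaled f32 values the two orders agree to roughly single-precision accuracy, nowhere near 5.99 vs 12.05:

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(4096).astype(np.float32) / np.float32(4096.0)

seq = np.float32(0.0)
for v in x:                     # a + b + c + d + ...
    seq += v

lanes = np.zeros(4, dtype=np.float32)
for row in x.reshape(1024, 4):  # (a + e) + (b + f) + ...: 4 partial accumulators
    lanes += row
split = np.float32(0.0)
for p in lanes:
    split += p

print(seq, split, abs(seq - split))  # the difference is tiny for well-scaled data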

cc @dcaballe

dcaballe commented 1 year ago

As you described, there's always going to be a certain level of fp reassociation if a reduction dimension is vectorized with split reduction, because the reduction order differs from the scalar one (assuming that is the golden one). However, that level of reassociation should be along the lines of the GPU one... I understand that with your fix, we do not use split reduction with --iree-llvmcpu-reassociate-fp-reductions=false? If the result with split reduction is still off, perhaps we have a bug there, but I have seen cases where the reduction error accumulated up to several orders of magnitude from the golden value...

MaheshRavishankar commented 1 year ago

As you described, there's always going to be a certain level of fp reassociation if a reduction dimension is vectorized with split reduction, because the reduction order differs from the scalar one (assuming that is the golden one). However, that level of reassociation should be along the lines of the GPU one... I understand that with your fix, we do not use split reduction with --iree-llvmcpu-reassociate-fp-reductions=false? If the result with split reduction is still off, perhaps we have a bug there, but I have seen cases where the reduction error accumulated up to several orders of magnitude from the golden value...

Yes, with the flag set to false, it won't use split reduction. But use of split reduction is the cause of the issue. I think this is an issue with the numerical stability of the input.
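
To illustrate the kind of input where reduction order alone can move an f32 sum far from the sequential result (synthetic values below, not the actual repro data):

import numpy as np

# Large cancelling terms interleaved with small ones; the exact sum is 2048.0.
x = np.array([1e8, 1.0, -1e8, 1.0] * 1024, dtype=np.float32)

in_order = np.float32(0.0)
for v in x:                     # sequential: the 1.0s get absorbed while the
    in_order += v               # accumulator holds 1e8

lanes = np.zeros(4, dtype=np.float32)
for row in x.reshape(1024, 4):  # 4-lane split: the 1.0s accumulate in their own lanes
    lanes += row
split = np.float32(0.0)
for p in lanes:
    split += p

print(in_order, split)          # 1.0 vs 1024.0 on IEEE f32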

MaheshRavishankar commented 1 year ago

Ok, I think I found the real issue here. It seems to be a vectorization bug.

// -----// IR Dump Before GenericVectorization (iree-codegen-generic-vectorization) //----- //
func.func @forward_dispatch_0_generic_4096_i64xf32() {
  %c1024 = arith.constant 1024 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %c1 = arith.constant 1 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %c0 = arith.constant 0 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %c-100 = arith.constant -100 : index loc(callsite("repro.mlir":7:14 at "repro.mlir":5:3))
  %cst = arith.constant 4.096000e+03 : f32 loc(callsite("repro.mlir":8:14 at "repro.mlir":5:3))
  %cst_0 = arith.constant 0.000000e+00 : f32 loc(callsite("repro.mlir":6:12 at "repro.mlir":5:3))
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x8192xf32>> loc("repro.mlir":5:3)
  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096xi64>> loc("repro.mlir":5:3)
  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<f32>> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %3 = flow.dispatch.tensor.load %2, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<writeonly:tensor<f32>> -> tensor<f32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %4 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [4096, 8192], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x8192xf32>> -> tensor<4096x8192xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %5 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [4096], strides = [1] : !flow.dispatch.tensor<readonly:tensor<4096xi64>> -> tensor<4096xi64> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %6 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[], [], [], []]>} ins(%cst_0 : f32) outs(%3 : tensor<f32>) -> tensor<f32> loc(callsite("repro.mlir":21:10 at "repro.mlir":5:3))
  %expanded = tensor.expand_shape %5 [[0, 1]] : tensor<4096xi64> into tensor<1024x4xi64> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %7 = tensor.empty() : tensor<4xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %8 = linalg.fill ins(%cst_0 : f32) outs(%7 : tensor<4xf32>) -> tensor<4xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %9 = scf.for %arg0 = %c0 to %c1024 step %c1 iter_args(%arg1 = %8) -> (tensor<4xf32>) {
    %extracted_slice = tensor.extract_slice %expanded[%arg0, 0] [1, 4] [1, 1] : tensor<1024x4xi64> to tensor<1x4xi64> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
    %11 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>], iterator_types = ["reduction", "parallel"]} ins(%extracted_slice : tensor<1x4xi64>) outs(%arg1 : tensor<4xf32>) {
    ^bb0(%in: i64 loc("repro.mlir":11:10), %out: f32 loc("repro.mlir":23:20)):
      %12 = linalg.index 0 : index loc(callsite("repro.mlir":10:10 at "repro.mlir":5:3))
      %13 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%12, %arg0) loc(callsite("repro.mlir":10:10 at "repro.mlir":5:3))
      %14 = arith.index_cast %in : i64 to index loc(callsite("repro.mlir":12:12 at "repro.mlir":5:3))
      %15 = arith.cmpi eq, %14, %c-100 : index loc(callsite("repro.mlir":13:12 at "repro.mlir":5:3))
      %extracted = tensor.extract %4[%13, %14] : tensor<4096x8192xf32> loc(callsite("repro.mlir":15:20 at "repro.mlir":5:3))
      %16 = arith.negf %extracted : f32 loc(callsite("repro.mlir":16:12 at "repro.mlir":5:3))
      %17 = arith.select %15, %cst_0, %16 : f32 loc(callsite("repro.mlir":17:12 at "repro.mlir":5:3))
      %18 = arith.divf %17, %cst : f32 loc(callsite("repro.mlir":24:12 at "repro.mlir":5:3))
      %19 = arith.addf %18, %out : f32 loc(callsite("repro.mlir":25:12 at "repro.mlir":5:3))
      linalg.yield %19 : f32 loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
    } -> tensor<4xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
    scf.yield %11 : tensor<4xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  } loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %10 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], iterator_types = ["reduction"]} ins(%9 : tensor<4xf32>) outs(%6 : tensor<f32>) {
  ^bb0(%in: f32 loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3)), %out: f32 loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))):
    %11 = arith.addf %in, %out : f32 loc(callsite("repro.mlir":25:12 at "repro.mlir":5:3))
    linalg.yield %11 : f32 loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  } -> tensor<f32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  flow.dispatch.tensor.store %10, %2, offsets = [], sizes = [], strides = [] : tensor<f32> -> !flow.dispatch.tensor<writeonly:tensor<f32>> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  return loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
} loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))

// -----// IR Dump After GenericVectorization (iree-codegen-generic-vectorization) //----- //
func.func @forward_dispatch_0_generic_4096_i64xf32() {
  %cst = arith.constant dense<0.000000e+00> : vector<f32> loc(callsite("repro.mlir":6:12 at "repro.mlir":5:3))
  %cst_0 = arith.constant dense<8192> : vector<1x4xindex> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %cst_1 = arith.constant dense<4.096000e+03> : vector<1x4xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %cst_2 = arith.constant dense<0.000000e+00> : vector<1x4xf32> loc(callsite("repro.mlir":15:20 at "repro.mlir":5:3))
  %cst_3 = arith.constant dense<true> : vector<1x4xi1> loc(callsite("repro.mlir":15:20 at "repro.mlir":5:3))
  %cst_4 = arith.constant dense<-100> : vector<1x4xindex> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %c0_i64 = arith.constant 0 : i64 loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %cst_5 = arith.constant dense<0.000000e+00> : vector<4xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %c1024 = arith.constant 1024 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %c1 = arith.constant 1 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %c0 = arith.constant 0 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %cst_6 = arith.constant 0.000000e+00 : f32 loc(callsite("repro.mlir":6:12 at "repro.mlir":5:3))
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x8192xf32>> loc("repro.mlir":5:3)
  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096xi64>> loc("repro.mlir":5:3)
  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<f32>> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %3 = flow.dispatch.tensor.load %2, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<writeonly:tensor<f32>> -> tensor<f32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %4 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [4096, 8192], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x8192xf32>> -> tensor<4096x8192xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %5 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [4096], strides = [1] : !flow.dispatch.tensor<readonly:tensor<4096xi64>> -> tensor<4096xi64> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %expanded = tensor.expand_shape %5 [[0, 1]] : tensor<4096xi64> into tensor<1024x4xi64> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %6 = tensor.empty() : tensor<4xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %7 = vector.transfer_write %cst_5, %6[%c0] {in_bounds = [true]} : vector<4xf32>, tensor<4xf32> loc(callsite("repro.mlir":6:12 at "repro.mlir":5:3))
  %8 = scf.for %arg0 = %c0 to %c1024 step %c1 iter_args(%arg1 = %7) -> (tensor<4xf32>) {
    %14 = vector.transfer_read %expanded[%arg0, %c0], %c0_i64 {in_bounds = [true, true]} : tensor<1024x4xi64>, vector<1x4xi64> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
    %15 = vector.transfer_read %arg1[%c0], %cst_6 {in_bounds = [true]} : tensor<4xf32>, vector<4xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
    %16 = vector.broadcast %arg0 : index to vector<1x4xindex> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
    %17 = arith.index_cast %14 : vector<1x4xi64> to vector<1x4xindex> loc(callsite("repro.mlir":12:12 at "repro.mlir":5:3))
    %18 = arith.cmpi eq, %17, %cst_4 : vector<1x4xindex> loc(callsite("repro.mlir":13:12 at "repro.mlir":5:3))
    %19 = arith.muli %16, %cst_0 : vector<1x4xindex> loc(callsite("repro.mlir":15:20 at "repro.mlir":5:3))
    %20 = arith.addi %17, %19 : vector<1x4xindex> loc(callsite("repro.mlir":15:20 at "repro.mlir":5:3))
    %21 = vector.gather %4[%c0, %c0] [%20], %cst_3, %cst_2 : tensor<4096x8192xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> loc(callsite("repro.mlir":15:20 at "repro.mlir":5:3))
    %22 = arith.negf %21 : vector<1x4xf32> loc(callsite("repro.mlir":16:12 at "repro.mlir":5:3))
    %23 = arith.select %18, %cst_2, %22 : vector<1x4xi1>, vector<1x4xf32> loc(callsite("repro.mlir":17:12 at "repro.mlir":5:3))
    %24 = arith.divf %23, %cst_1 : vector<1x4xf32> loc(callsite("repro.mlir":24:12 at "repro.mlir":5:3))
    %25 = vector.multi_reduction <add>, %24, %15 [0] : vector<1x4xf32> to vector<4xf32> loc(callsite("repro.mlir":25:12 at "repro.mlir":5:3))
    %26 = vector.transfer_write %25, %arg1[%c0] {in_bounds = [true]} : vector<4xf32>, tensor<4xf32> loc(callsite("repro.mlir":25:12 at "repro.mlir":5:3))
    scf.yield %26 : tensor<4xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  } loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %9 = vector.transfer_read %8[%c0], %cst_6 {in_bounds = [true]} : tensor<4xf32>, vector<4xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %10 = vector.extractelement %cst[] : vector<f32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %11 = vector.multi_reduction <add>, %9, %10 [0] : vector<4xf32> to f32 loc(callsite("repro.mlir":25:12 at "repro.mlir":5:3))
  %12 = vector.broadcast %11 : f32 to vector<f32> loc(callsite("repro.mlir":25:12 at "repro.mlir":5:3))
  %13 = vector.transfer_write %12, %3[] : vector<f32>, tensor<f32> loc(callsite("repro.mlir":25:12 at "repro.mlir":5:3))
  flow.dispatch.tensor.store %13, %2, offsets = [], sizes = [], strides = [] : tensor<f32> -> !flow.dispatch.tensor<writeonly:tensor<f32>> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  return loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
} loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))

In particular, take this part before vectorization:

%11 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>], iterator_types = ["reduction", "parallel"]} ins(%extracted_slice : tensor<1x4xi64>) outs(%arg1 : tensor<4xf32>) {
    ^bb0(%in: i64 loc("repro.mlir":11:10), %out: f32 loc("repro.mlir":23:20)):
      %12 = linalg.index 0 : index loc(callsite("repro.mlir":10:10 at "repro.mlir":5:3))
      %13 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%12, %arg0) loc(callsite("repro.mlir":10:10 at "repro.mlir":5:3))

After vectorization, that seems to become:

  %8 = scf.for %arg0 = %c0 to %c1024 step %c1 iter_args(%arg1 = %7) -> (tensor<4xf32>) {
    %14 = vector.transfer_read %expanded[%arg0, %c0], %c0_i64 {in_bounds = [true, true]} : tensor<1024x4xi64>, vector<1x4xi64> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
    %15 = vector.transfer_read %arg1[%c0], %cst_6 {in_bounds = [true]} : tensor<4xf32>, vector<4xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
    %16 = vector.broadcast %arg0 : index to vector<1x4xindex> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
    %17 = arith.index_cast %14 : vector<1x4xi64> to vector<1x4xindex> loc(callsite("repro.mlir":12:12 at "repro.mlir":5:3))

If I am reading this correctly, %17 should be [%arg0, %arg0 + 1, %arg0 + 2, %arg0 + 3]; instead it is [%arg0, %arg0, %arg0, %arg0].

@dcaballe could you please verify?

dcaballe commented 1 year ago

So, we are vectorizing dim0_size=1, dim1_size=4 with vector<1x4>. %12 = linalg.index 0 refers to dim0 (size 1), so it's going to be zero all the time, which would make the affine.apply result [%arg0, %arg0, %arg0, %arg0]. That looks OK to me.

MaheshRavishankar commented 1 year ago

Actually, it was an issue with the split reduction pass. It doesn't seem to handle operations with indexing semantics correctly.

// -----// IR Dump Before LLVMCPUSplitReduction (iree-llvmcpu-split-reduction) //----- //
func.func @forward_dispatch_0_generic_4096_i64xf32() {
  %c0 = arith.constant 0 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %c-100 = arith.constant -100 : index loc(callsite("repro.mlir":7:14 at "repro.mlir":5:3))
  %cst = arith.constant 4.096000e+03 : f32 loc(callsite("repro.mlir":8:14 at "repro.mlir":5:3))
  %cst_0 = arith.constant 0.000000e+00 : f32 loc(callsite("repro.mlir":6:12 at "repro.mlir":5:3))
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x8192xf32>> loc("repro.mlir":5:3)
  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096xi64>> loc("repro.mlir":5:3)
  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<f32>> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %3 = flow.dispatch.tensor.load %2, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<writeonly:tensor<f32>> -> tensor<f32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %4 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [4096, 8192], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x8192xf32>> -> tensor<4096x8192xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %5 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [4096], strides = [1] : !flow.dispatch.tensor<readonly:tensor<4096xi64>> -> tensor<4096xi64> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %6 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[], [], [], []]>} ins(%cst_0 : f32) outs(%3 : tensor<f32>) -> tensor<f32> loc(callsite("repro.mlir":21:10 at "repro.mlir":5:3))
  %7 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], iterator_types = ["reduction"]} ins(%5 : tensor<4096xi64>) outs(%6 : tensor<f32>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0], [0], [4], [0]]>} {
  ^bb0(%in: i64 loc("repro.mlir":11:10), %out: f32 loc("repro.mlir":23:20)):
    %8 = linalg.index 0 : index loc(callsite("repro.mlir":10:10 at "repro.mlir":5:3))
    %9 = arith.index_cast %in : i64 to index loc(callsite("repro.mlir":12:12 at "repro.mlir":5:3))
    %10 = arith.cmpi eq, %9, %c-100 : index loc(callsite("repro.mlir":13:12 at "repro.mlir":5:3))
    %extracted = tensor.extract %4[%8, %9] : tensor<4096x8192xf32> loc(callsite("repro.mlir":15:20 at "repro.mlir":5:3))
    %11 = arith.negf %extracted : f32 loc(callsite("repro.mlir":16:12 at "repro.mlir":5:3))
    %12 = arith.select %10, %cst_0, %11 : f32 loc(callsite("repro.mlir":17:12 at "repro.mlir":5:3))
    %13 = arith.divf %12, %cst : f32 loc(callsite("repro.mlir":24:12 at "repro.mlir":5:3))
    %14 = arith.addf %13, %out : f32 loc(callsite("repro.mlir":25:12 at "repro.mlir":5:3))
    linalg.yield %14 : f32 loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  } -> tensor<f32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  flow.dispatch.tensor.store %7, %2, offsets = [], sizes = [], strides = [] : tensor<f32> -> !flow.dispatch.tensor<writeonly:tensor<f32>> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  return loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
} loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))

// -----// IR Dump After LLVMCPUSplitReduction (iree-llvmcpu-split-reduction) //----- //
func.func @forward_dispatch_0_generic_4096_i64xf32() {
  %c1 = arith.constant 1 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %c0 = arith.constant 0 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %c0_0 = arith.constant 0 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %c0_1 = arith.constant 0 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %c-100 = arith.constant -100 : index loc(callsite("repro.mlir":7:14 at "repro.mlir":5:3))
  %cst = arith.constant 4.096000e+03 : f32 loc(callsite("repro.mlir":8:14 at "repro.mlir":5:3))
  %cst_2 = arith.constant 0.000000e+00 : f32 loc(callsite("repro.mlir":6:12 at "repro.mlir":5:3))
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0_1) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x8192xf32>> loc("repro.mlir":5:3)
  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0_1) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096xi64>> loc("repro.mlir":5:3)
  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0_1) : !flow.dispatch.tensor<writeonly:tensor<f32>> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %3 = flow.dispatch.tensor.load %2, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<writeonly:tensor<f32>> -> tensor<f32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %4 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [4096, 8192], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x8192xf32>> -> tensor<4096x8192xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %5 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [4096], strides = [1] : !flow.dispatch.tensor<readonly:tensor<4096xi64>> -> tensor<4096xi64> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %6 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[], [], [], []]>} ins(%cst_2 : f32) outs(%3 : tensor<f32>) -> tensor<f32> loc(callsite("repro.mlir":21:10 at "repro.mlir":5:3))
  %c0_3 = arith.constant 0 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %c4096 = arith.constant 4096 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %extracted_slice = tensor.extract_slice %5[0] [4096] [1] : tensor<4096xi64> to tensor<4096xi64> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %extracted_slice_4 = tensor.extract_slice %6[] [] [] : tensor<f32> to tensor<f32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %expanded = tensor.expand_shape %extracted_slice [[0, 1]] : tensor<4096xi64> into tensor<1024x4xi64> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %7 = tensor.empty() : tensor<4xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %cst_5 = arith.constant 0.000000e+00 : f32 loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %8 = linalg.fill ins(%cst_5 : f32) outs(%7 : tensor<4xf32>) -> tensor<4xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %c0_6 = arith.constant 0 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %c1024 = arith.constant 1024 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %9 = scf.for %arg0 = %c0_6 to %c1024 step %c1 iter_args(%arg1 = %8) -> (tensor<4xf32>) {
    %c0_7 = arith.constant 0 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
    %c4 = arith.constant 4 : index loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
    %extracted_slice_8 = tensor.extract_slice %expanded[%arg0, 0] [1, 4] [1, 1] : tensor<1024x4xi64> to tensor<1x4xi64> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
    %extracted_slice_9 = tensor.extract_slice %arg1[0] [4] [1] : tensor<4xf32> to tensor<4xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
    %11 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>], iterator_types = ["reduction", "parallel"]} ins(%extracted_slice_8 : tensor<1x4xi64>) outs(%extracted_slice_9 : tensor<4xf32>) {
    ^bb0(%in: i64 loc("repro.mlir":11:10), %out: f32 loc("repro.mlir":23:20)):
      %12 = linalg.index 0 : index loc(callsite("repro.mlir":10:10 at "repro.mlir":5:3))
      %13 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%12, %arg0) loc(callsite("repro.mlir":10:10 at "repro.mlir":5:3))
      %14 = arith.index_cast %in : i64 to index loc(callsite("repro.mlir":12:12 at "repro.mlir":5:3))
      %15 = arith.cmpi eq, %14, %c-100 : index loc(callsite("repro.mlir":13:12 at "repro.mlir":5:3))
      %extracted = tensor.extract %4[%13, %14] : tensor<4096x8192xf32> loc(callsite("repro.mlir":15:20 at "repro.mlir":5:3))
      %16 = arith.negf %extracted : f32 loc(callsite("repro.mlir":16:12 at "repro.mlir":5:3))
      %17 = arith.select %15, %cst_2, %16 : f32 loc(callsite("repro.mlir":17:12 at "repro.mlir":5:3))
      %18 = arith.divf %17, %cst : f32 loc(callsite("repro.mlir":24:12 at "repro.mlir":5:3))
      %19 = arith.addf %18, %out : f32 loc(callsite("repro.mlir":25:12 at "repro.mlir":5:3))
      linalg.yield %19 : f32 loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
    } -> tensor<4xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
    %inserted_slice = tensor.insert_slice %11 into %arg1[0] [4] [1] : tensor<4xf32> into tensor<4xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
    scf.yield %inserted_slice : tensor<4xf32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  } loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  %10 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], iterator_types = ["reduction"]} ins(%9 : tensor<4xf32>) outs(%extracted_slice_4 : tensor<f32>) {
  ^bb0(%in: f32 loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3)), %out: f32 loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))):
    %11 = arith.addf %in, %out : f32 loc(callsite("repro.mlir":25:12 at "repro.mlir":5:3))
    linalg.yield %11 : f32 loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  } -> tensor<f32> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  flow.dispatch.tensor.store %10, %2, offsets = [], sizes = [], strides = [] : tensor<f32> -> !flow.dispatch.tensor<writeonly:tensor<f32>> loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
  return loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))
} loc(callsite("repro.mlir":22:10 at "repro.mlir":5:3))

Basically in the original code

    %8 = linalg.index 0 : index loc(callsite("repro.mlir":10:10 at "repro.mlir":5:3))

must be replaced with

     %8 = linalg.index 0 : index loc(callsite("repro.mlir":10:10 at "repro.mlir":5:3))
     %9 = linalg.index 1 : index loc(callsite("repro.mlir":10:10 at "repro.mlir":5:3))
     %10 = affine.apply affine_map<()[s0, s1] -> (s0 * 4 + s1)>()[%8, %9]

I sent a PR to disable split reduction for ops with indexing semantics. That needs to be fixed properly, though.
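
A hypothetical scalar illustration of the index bug: after expanding the length-4096 iteration space into 1024x4, the original flat index of element (i, j) is 4 * i + j, so keeping only the outer index makes all four lanes in a group read the same row:

import numpy as np

data = np.arange(4096, dtype=np.float32)

correct = np.float32(0.0)
buggy = np.float32(0.0)
for i in range(1024):               # outer (split) loop
    for j in range(4):              # inner lanes
        correct += data[4 * i + j]  # recovered flat index
        buggy += data[i]            # outer index only, as in the IR above
print(correct, buggy)               # 8386560.0 vs 2095104.0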

@dcaballe (or @vmurali, who was the original author of this pass), maybe we need to do an overhaul of this pass to not use the linalg::splitReduction method, which is a total hack.