iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/

large vector sizes failure - cpu compilation - quantised models #18005

Open PhaneeshB opened 1 month ago

PhaneeshB commented 1 month ago

What happened?

When compiling a model with int8 quantization, one of the dispatches fails to compile with the following error:

error: One or more operations with large vector sizes (16384 bytes) were found 

Min repro adapted from the failing dispatch:

module {
  func.func @largeVectorMinRepro(%arg0: tensor<1x320x65x65xi8>) -> tensor<1x320x1x1xf32> {
    %cst = arith.constant 1.250000e-01 : f32
    %cst_0 = arith.constant 0.000000e+00 : f32
    %c5408000 = arith.constant 5408000 : index
    %c0 = arith.constant 0 : index
    %3 = tensor.empty() : tensor<1x320x1x1xf32>
    %4 = tensor.empty() : tensor<65x65xf32>
    %5 = tensor.empty() : tensor<1x320x65x65xf32>
    %6 = linalg.fill ins(%cst_0 : f32) outs(%3 : tensor<1x320x1x1xf32>) -> tensor<1x320x1x1xf32>
    %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x320x65x65xi8>) outs(%5 : tensor<1x320x65x65xf32>) {
    ^bb0(%in: i8, %out: f32):
      %9 = arith.extsi %in : i8 to i32
      %10 = arith.sitofp %9 : i32 to f32
      %11 = arith.mulf %10, %cst : f32
      linalg.yield %11 : f32
    } -> tensor<1x320x65x65xf32>
    %8 = linalg.pooling_nchw_sum ins(%7, %4 : tensor<1x320x65x65xf32>, tensor<65x65xf32>) outs(%6 : tensor<1x320x1x1xf32>) -> tensor<1x320x1x1xf32>
    return %8 : tensor<1x320x1x1xf32>
  }
}

Compile command: iree-compile --iree-input-demote-i64-to-i32 --iree-hal-target-backends=llvm-cpu largevectorissue.minrepro.mlir -o test.vmfb

Host issue: here


MaheshRavishankar commented 1 month ago

@pashu123 this seems like a codegen bug. Please take a look.

hanhanW commented 1 month ago

I took a look this morning, but I did not get time to write down my observations. I'll do that soon.

hanhanW commented 1 month ago

Here is the IR before vectorization. It is very similar to what @pashu123 and I saw in the broadcast + mmt4d fusion. The dequant op is not fused into the reduction loops, so it ends up with a large vector size and a large stack buffer. Based on a discussion we had a few weeks ago, it is worth fusing the dequant op into the reduction loops even if it introduces redundant computation. One of the goals is to have fewer memory loads/stores in the compute body, so we want to fuse these ops into the reduction loops. This is also what we did in the llama2 performance burn-down (i.e., the i16i4i32 ukernel).

So I think the fix could be updating LLVMCPUTile to LLVMCPUTileReductionAndFuseInputOperands, so that we are able to fuse input operands into the reduction loops. It will also be needed for our new mmt4d pipeline. @pashu123 perhaps you can implement the pass (or update LLVMCPUTile) and use it in the convolution pipeline?
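In loop form, the difference is roughly the following (a hedged C++ sketch of one tile of the computation, not the generated code; the function names and the dequant helper are illustrative). The actual IR dump follows below.

#include <cstdint>

// Scale constant from the repro (%cst = 1.250000e-01).
static inline float dequant(int8_t v) { return float(v) * 0.125f; }

// Unfused (what codegen currently produces, in spirit): materialize the whole
// dequantized 16x65x65 f32 tile, then reduce it. The f32 temporary is the
// large vector / stack buffer that the error complains about.
void pool_unfused(const int8_t in[16][65][65], float acc[16]) {
  static float tmp[16][65][65];
  for (int c = 0; c < 16; ++c)
    for (int h = 0; h < 65; ++h)
      for (int w = 0; w < 65; ++w)
        tmp[c][h][w] = dequant(in[c][h][w]);   // dequant linalg.generic
  for (int c = 0; c < 16; ++c)
    for (int h = 0; h < 65; ++h)
      for (int w = 0; w < 65; ++w)
        acc[c] += tmp[c][h][w];                // pooling_nchw_sum
}

// Fused (the proposal): dequantize inside the reduction loops, so only a
// small slice of dequantized data is live at any time, at the cost of
// possibly recomputing the dequant for overlapping windows.
void pool_fused(const int8_t in[16][65][65], float acc[16]) {
  for (int c = 0; c < 16; ++c)
    for (int h = 0; h < 65; ++h)
      for (int w = 0; w < 65; ++w)
        acc[c] += dequant(in[c][h][w]);
}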

// -----// IR Dump Before GenericVectorization (iree-codegen-generic-vectorization) //----- //
func.func @largeVectorMinRepro_dispatch_0_pooling_nchw_sum_1x320x1x1x65x65_f32() attributes {translation_info = #iree_codegen.translation_info<CPUConvTileAndDecomposeExpert>} {
  %c5 = arith.constant 5 : index
  %c1 = arith.constant 1 : index
  %c65 = arith.constant 65 : index
  %c16 = arith.constant 16 : index
  %c32 = arith.constant 32 : index
  %c320 = arith.constant 320 : index
  %cst = arith.constant 1.250000e-01 : f32
  %cst_0 = arith.constant 0.000000e+00 : f32
  %c0 = arith.constant 0 : index
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1x320x65x65xi8>>
  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<1x320x1x1xf32>>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %workgroup_count_x = hal.interface.workgroup.count[0] : index
  %2 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
  %3 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_x]
  scf.for %arg0 = %2 to %c320 step %3 {
    %4 = flow.dispatch.tensor.load %1, offsets = [0, %arg0, 0, 0], sizes = [1, 32, 1, 1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<1x320x1x1xf32>> -> tensor<1x32x1x1xf32>
    %5 = flow.dispatch.tensor.load %0, offsets = [0, %arg0, 0, 0], sizes = [1, 32, 65, 65], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x320x65x65xi8>> -> tensor<1x32x65x65xi8>
    %6 = scf.for %arg1 = %c0 to %c32 step %c16 iter_args(%arg2 = %4) -> (tensor<1x32x1x1xf32>) {
      %extracted_slice = tensor.extract_slice %5[0, %arg1, 0, 0] [1, 16, 65, 65] [1, 1, 1, 1] : tensor<1x32x65x65xi8> to tensor<1x16x65x65xi8>
      %7 = tensor.empty() : tensor<1x16x65x65xf32>
      %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%extracted_slice : tensor<1x16x65x65xi8>) outs(%7 : tensor<1x16x65x65xf32>) {
      ^bb0(%in: i8, %out: f32):
        %11 = arith.extsi %in : i8 to i32
        %12 = arith.sitofp %11 : i32 to f32
        %13 = arith.mulf %12, %cst : f32
        linalg.yield %13 : f32
      } -> tensor<1x16x65x65xf32>
      %extracted_slice_1 = tensor.extract_slice %arg2[0, %arg1, 0, 0] [1, 16, 1, 1] [1, 1, 1, 1] : tensor<1x32x1x1xf32> to tensor<1x16x1x1xf32>
      %9 = linalg.fill ins(%cst_0 : f32) outs(%extracted_slice_1 : tensor<1x16x1x1xf32>) -> tensor<1x16x1x1xf32>
      %10 = scf.for %arg3 = %c0 to %c65 step %c1 iter_args(%arg4 = %9) -> (tensor<1x16x1x1xf32>) {
        %11 = scf.for %arg5 = %c0 to %c65 step %c5 iter_args(%arg6 = %arg4) -> (tensor<1x16x1x1xf32>) {
          %extracted_slice_2 = tensor.extract_slice %8[0, 0, %arg3, %arg5] [1, 16, 1, 5] [1, 1, 1, 1] : tensor<1x16x65x65xf32> to tensor<1x16x1x5xf32>
          %12 = tensor.empty() : tensor<1x5xf32>
          %extracted_slice_3 = tensor.extract_slice %extracted_slice_2[0, 0, 0, 0] [1, 16, 1, 5] [1, 1, 1, 1] : tensor<1x16x1x5xf32> to tensor<1x16x5xf32>
          %extracted_slice_4 = tensor.extract_slice %12[0, 0] [1, 5] [1, 1] : tensor<1x5xf32> to tensor<5xf32>
          %extracted_slice_5 = tensor.extract_slice %arg6[0, 0, 0, 0] [1, 16, 1, 1] [1, 1, 1, 1] : tensor<1x16x1x1xf32> to tensor<1x16x1xf32>
          %13 = linalg.pooling_ncw_sum {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>} ins(%extracted_slice_3, %extracted_slice_4 : tensor<1x16x5xf32>, tensor<5xf32>) outs(%extracted_slice_5 : tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
          %inserted_slice_6 = tensor.insert_slice %13 into %arg6[0, 0, 0, 0] [1, 16, 1, 1] [1, 1, 1, 1] : tensor<1x16x1xf32> into tensor<1x16x1x1xf32>
          scf.yield %inserted_slice_6 : tensor<1x16x1x1xf32>
        }
        scf.yield %11 : tensor<1x16x1x1xf32>
      }
      %inserted_slice = tensor.insert_slice %10 into %arg2[0, %arg1, 0, 0] [1, 16, 1, 1] [1, 1, 1, 1] : tensor<1x16x1x1xf32> into tensor<1x32x1x1xf32>
      scf.yield %inserted_slice : tensor<1x32x1x1xf32>
    }
    flow.dispatch.tensor.store %6, %1, offsets = [0, %arg0, 0, 0], sizes = [1, 32, 1, 1], strides = [1, 1, 1, 1] : tensor<1x32x1x1xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x320x1x1xf32>>
  }
  return
}
pashu123 commented 1 month ago

> Here is the IR before vectorization. […] So I think the fix could be updating LLVMCPUTile to LLVMCPUTileReductionAndFuseInputOperands, so that we are able to fuse input operands into the reduction loops. […] @pashu123 perhaps you can implement the pass (or update LLVMCPUTile) and use it in the convolution pipeline?

Why can't we just add funcPassManager.addPass(createLLVMCPUTileAndFusePass(tilingConfig.getVectorReductionLevel())); after https://github.com/iree-org/iree/blob/456d80c51930ccc03ce0488e98238e5e0a14b403/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp#L448
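
For concreteness, a hedged sketch of that suggestion (the pass and accessor names are as quoted above; the wrapper function name and the exact insertion point in the pipeline are assumptions, and the surrounding passes in Passes.cpp are elided):

// Sketch only: fragment of the CPUConvTileAndDecomposeExpert pipeline in
// compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp.
void buildConvPipelineFragment(OpPassManager &funcPassManager,
                               TilingConfig &tilingConfig) {
  // ... existing passes up to the reduction tiling step (around L448 in the
  // linked revision) ...
  // Proposed addition: tile the reduction level *and* fuse producers, so the
  // dequant linalg.generic is pulled into the reduction loops.
  funcPassManager.addPass(
      createLLVMCPUTileAndFusePass(tilingConfig.getVectorReductionLevel()));
  // ... rest of the pipeline ...
}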

This is what I see before generic vectorization and the program successfully compiles:

   func.func @largeVectorMinRepro_dispatch_0_pooling_nchw_sum_1x320x1x1x65x65_f32() attributes {translation_info = #iree_codegen.translation_info<CPUConvTileAndDecomposeExpert>} {
     %c5 = arith.constant 5 : index
     %c1 = arith.constant 1 : index
     %c65 = arith.constant 65 : index
     %c16 = arith.constant 16 : index
     %c32 = arith.constant 32 : index
     %c320 = arith.constant 320 : index
     %cst = arith.constant 1.250000e-01 : f32
     %cst_0 = arith.constant 0.000000e+00 : f32
     %c0 = arith.constant 0 : index
     %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1x320x65x65xi8>>
     %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<1x320x1x1xf32>>
     %workgroup_id_x = hal.interface.workgroup.id[0] : index
     %workgroup_count_x = hal.interface.workgroup.count[0] : index
     %2 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
     %3 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_x]
     scf.for %arg0 = %2 to %c320 step %3 {
       %4 = flow.dispatch.tensor.load %1, offsets = [0, %arg0, 0, 0], sizes = [1, 32, 1, 1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<1x320x1x1xf32>> -> tensor<1x32x1x1xf32>
       %5 = flow.dispatch.tensor.load %0, offsets = [0, %arg0, 0, 0], sizes = [1, 32, 65, 65], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x320x65x65xi8>> -> tensor<1x32x65x65xi8>
       %6 = scf.for %arg1 = %c0 to %c32 step %c16 iter_args(%arg2 = %4) -> (tensor<1x32x1x1xf32>) {
         %extracted_slice = tensor.extract_slice %5[0, %arg1, 0, 0] [1, 16, 65, 65] [1, 1, 1, 1] : tensor<1x32x65x65xi8> to tensor<1x16x65x65xi8>
         %extracted_slice_1 = tensor.extract_slice %arg2[0, %arg1, 0, 0] [1, 16, 1, 1] [1, 1, 1, 1] : tensor<1x32x1x1xf32> to tensor<1x16x1x1xf32>
         %7 = scf.for %arg3 = %c0 to %c65 step %c1 iter_args(%arg4 = %extracted_slice_1) -> (tensor<1x16x1x1xf32>) {
           %8 = scf.for %arg5 = %c0 to %c65 step %c5 iter_args(%arg6 = %arg4) -> (tensor<1x16x1x1xf32>) {
             %extracted_slice_2 = tensor.extract_slice %extracted_slice[0, 0, %arg3, %arg5] [1, 16, 1, 5] [1, 1, 1, 1] : tensor<1x16x65x65xi8> to tensor<1x16x1x5xi8>
             %9 = tensor.empty() : tensor<1x16x1x5xf32>
              %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%extracted_slice_2 : tensor<1x16x1x5xi8>) outs(%9 : tensor<1x16x1x5xf32>) {
             ^bb0(%in: i8, %out: f32):
               %14 = arith.extsi %in : i8 to i32
               %15 = arith.sitofp %14 : i32 to f32
               %16 = arith.mulf %15, %cst : f32
               linalg.yield %16 : f32
             } -> tensor<1x16x1x5xf32>
             %11 = linalg.fill ins(%cst_0 : f32) outs(%arg6 : tensor<1x16x1x1xf32>) -> tensor<1x16x1x1xf32>
             %12 = tensor.empty() : tensor<1x5xf32>
             %extracted_slice_3 = tensor.extract_slice %10[0, 0, 0, 0] [1, 16, 1, 5] [1, 1, 1, 1] : tensor<1x16x1x5xf32> to tensor<1x16x5xf32>
             %extracted_slice_4 = tensor.extract_slice %12[0, 0] [1, 5] [1, 1] : tensor<1x5xf32> to tensor<5xf32>
             %extracted_slice_5 = tensor.extract_slice %11[0, 0, 0, 0] [1, 16, 1, 1] [1, 1, 1, 1] : tensor<1x16x1x1xf32> to tensor<1x16x1xf32>
              %13 = linalg.pooling_ncw_sum {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>} ins(%extracted_slice_3, %extracted_slice_4 : tensor<1x16x5xf32>, tensor<5xf32>) outs(%extracted_slice_5 : tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
             %inserted_slice_6 = tensor.insert_slice %13 into %11[0, 0, 0, 0] [1, 16, 1, 1] [1, 1, 1, 1] : tensor<1x16x1xf32> into tensor<1x16x1x1xf32>
             scf.yield %inserted_slice_6 : tensor<1x16x1x1xf32>
           }
           scf.yield %8 : tensor<1x16x1x1xf32>
         }
         %inserted_slice = tensor.insert_slice %7 into %arg2[0, %arg1, 0, 0] [1, 16, 1, 1] [1, 1, 1, 1] : tensor<1x16x1x1xf32> into tensor<1x32x1x1xf32>
         scf.yield %inserted_slice : tensor<1x32x1x1xf32>
       }
       flow.dispatch.tensor.store %6, %1, offsets = [0, %arg0, 0, 0], sizes = [1, 32, 1, 1], strides = [1, 1, 1, 1] : tensor<1x32x1x1xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x320x1x1xf32>>
     }
     return
   }
hanhanW commented 1 month ago

I think there are numeric issues, because you initialize the accumulator to zero on every reduction iteration. That's why I'm saying we should only fuse the input operands when tiling the reduction loops.

Let's take gemm as an example. Code1 is what you're doing when you fuse the fill op, but what we want is Code2. Does it make sense?

Code1:

for (int i = 0; i < M; ++i) {
  for (int j = 0; j < N; ++j) {
    for (int k = 0; k < K; ++k) {
      int64_t acc = 0;           // linalg.fill, re-initialized every k iteration
      acc += A[i][k] * B[k][j];  // linalg.matmul
      C[i][j] = acc;             // scf.yield
    }
  }
}

Code2:

for (int i = 0; i < M; ++i) {
  for (int j = 0; j < N; ++j) {
    C[i][j] = 0;            // linalg.fill, once per (i, j)
    int64_t acc = C[i][j];  // iter_args(%arg6 = %arg4)
    for (int k = 0; k < K; ++k) {
      acc += A[i][k] * B[k][j];  // linalg.matmul
      C[i][j] = acc;             // scf.yield
    }
  }
}

(You can try e2e tests with your suggestion; I think it will generate wrong outputs.)
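A tiny runnable version of the two loop nests above makes the numeric issue concrete (a hedged sketch with a 2x2 matmul, not from the source; Code1 keeps only the last partial product):

#include <cstdint>
#include <cstdio>

int main() {
  const int M = 2, N = 2, K = 2;
  int64_t A[2][2] = {{1, 2}, {3, 4}};
  int64_t B[2][2] = {{5, 6}, {7, 8}};
  int64_t C1[2][2] = {}, C2[2][2] = {};

  // Code1: the fill is fused into the reduction loop, so acc is re-zeroed
  // on every k iteration and earlier partial products are discarded.
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j)
      for (int k = 0; k < K; ++k) {
        int64_t acc = 0;           // linalg.fill (wrongly inside the k loop)
        acc += A[i][k] * B[k][j];  // linalg.matmul
        C1[i][j] = acc;            // scf.yield
      }

  // Code2: the fill stays outside; the accumulator is carried across k.
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j) {
      int64_t acc = 0;             // linalg.fill, once per (i, j)
      for (int k = 0; k < K; ++k)
        acc += A[i][k] * B[k][j];  // linalg.matmul
      C2[i][j] = acc;
    }

  // Prints 14 vs 19: Code1 only kept the k == 1 product.
  printf("Code1 C[0][0] = %lld, Code2 C[0][0] = %lld\n",
         (long long)C1[0][0], (long long)C2[0][0]);
  return 0;
}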

AmosLewis commented 1 month ago

Got the same large vector sizes failure for the ONNX models dpn68_vaiq/dpn92_vaiq/dpn98_vaiq/dpn107_vaiq/dpn131_vaiq/skresnet34_vaiq/skresnet18_vaiq/DeepLabV3_resnet50_vaiq_int8/RAFT_vaiq_int8/U-2-Net_vaiq_int8 in the public ONNX storage. Here is one of the detailed logs: dpn68_vaiq_iree_failed.log

pashu123 commented 1 month ago

> Got the same large vector sizes failure for the ONNX models dpn68_vaiq/dpn92_vaiq/dpn98_vaiq/dpn107_vaiq/dpn131_vaiq/skresnet34_vaiq/skresnet18_vaiq/DeepLabV3_resnet50_vaiq_int8/RAFT_vaiq_int8/U-2-Net_vaiq_int8 in the public ONNX storage. Here is one of the detailed logs: dpn68_vaiq_iree_failed.log

@AmosLewis Please check out the following commit and try https://github.com/iree-org/iree/pull/18114/commits/b0b3dea51146dcec66648859a7d3bc074bc49f97

AmosLewis commented 1 month ago

> @AmosLewis Please check out the following commit and try b0b3dea

Why b0b3dea? It has been replaced by a newer commit in your PR https://github.com/iree-org/iree/pull/18114. I checked out 55f161160e directly, which is your most recent change in https://github.com/iree-org/iree/pull/18114, and it still fails with the same error for the dpn68_vaiq model.

pashu123 commented 1 month ago

> > @AmosLewis Please check out the following commit and try b0b3dea

> Why b0b3dea? It has been replaced by a newer commit in your PR #18114. I checked out 55f1611 directly, which is your most recent change in #18114, and it still fails with the same error for the dpn68_vaiq model.

There are no issues; you can use the latest commit as well, the previous one had just been tested more. Could you use --iree-hal-dump-executable-sources-to=... and paste the failing dispatch?

AmosLewis commented 1 month ago

@pashu123 Here is the command: iree-compile --iree-input-demote-i64-to-i32 --iree-hal-target-backends=llvm-cpu dpn68_vaiq.default.onnx.linalg.mlir > dpn68_vaiq.default.vmfb --iree-hal-dump-executable-sources-to=./dispatch

module_main_graph_dispatch_34.mlir:

hal.executable public @main_graph_dispatch_34 {
  hal.executable.variant public @embedded_elf_x86_64 target(<"llvm-cpu", "embedded-elf-x86_64", {cpu = "generic", cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 16 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf"}>) {
    hal.executable.export public @main_graph_dispatch_34_elementwise_64x56x56_f32xf32xi8 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>], flags = Indirect>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>]} {
    ^bb0(%arg0: !hal.device):
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
      hal.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @main_graph_dispatch_34_elementwise_64x56x56_f32xf32xi8() {
        %cst = arith.constant 0.000000e+00 : f32
        %cst_0 = arith.constant -1.280000e+02 : f32
        %cst_1 = arith.constant 1.270000e+02 : f32
        %cst_2 = arith.constant 1.562500e-02 : f32
        %c2408448 = arith.constant 2408448 : index
        %c2207744 = arith.constant 2207744 : index
        %c802816 = arith.constant 802816 : index
        %0 = hal.interface.binding.subspan layout(<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>], flags = Indirect>]>) set(0) binding(0) alignment(64) offset(%c2408448) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1x80x56x56xf32>>
        %1 = hal.interface.binding.subspan layout(<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>], flags = Indirect>]>) set(0) binding(0) alignment(64) offset(%c2207744) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<200704xi8>>
        %2 = hal.interface.binding.subspan layout(<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>], flags = Indirect>]>) set(0) binding(1) alignment(64) offset(%c802816) : !flow.dispatch.tensor<writeonly:tensor<64x56x56xi8>>
        %3 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [200704], strides = [1] : !flow.dispatch.tensor<readonly:tensor<200704xi8>> -> tensor<200704xi8>
        %4 = tensor.empty() : tensor<64x56x56xi8>
        %5 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 64, 56, 56], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x80x56x56xf32>> -> tensor<64x56x56xf32>
        %6 = tensor.empty() : tensor<200704xf32>
        %7 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%3 : tensor<200704xi8>) outs(%6 : tensor<200704xf32>) {
        ^bb0(%in: i8, %out: f32):
          %9 = arith.extsi %in : i8 to i32
          %10 = arith.sitofp %9 : i32 to f32
          %11 = arith.mulf %10, %cst_2 : f32
          linalg.yield %11 : f32
        } -> tensor<200704xf32>
        %expanded = tensor.expand_shape %7 [[0, 1, 2]] output_shape [64, 56, 56] : tensor<200704xf32> into tensor<64x56x56xf32>
        %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded, %5 : tensor<64x56x56xf32>, tensor<64x56x56xf32>) outs(%4 : tensor<64x56x56xi8>) {
        ^bb0(%in: f32, %in_3: f32, %out: i8):
          %9 = arith.divf %in_3, %cst_2 : f32
          %10 = math.roundeven %9 : f32
          %11 = arith.addf %10, %cst : f32
          %12 = arith.maximumf %11, %cst_0 : f32
          %13 = arith.minimumf %12, %cst_1 : f32
          %14 = arith.fptosi %13 : f32 to i8
          %15 = arith.extsi %14 : i8 to i32
          %16 = arith.sitofp %15 : i32 to f32
          %17 = arith.mulf %16, %cst_2 : f32
          %18 = arith.addf %in, %17 : f32
          %19 = arith.divf %18, %cst_2 : f32
          %20 = math.roundeven %19 : f32
          %21 = arith.addf %20, %cst : f32
          %22 = arith.maximumf %21, %cst_0 : f32
          %23 = arith.minimumf %22, %cst_1 : f32
          %24 = arith.fptosi %23 : f32 to i8
          linalg.yield %24 : i8
        } -> tensor<64x56x56xi8>
        flow.dispatch.tensor.store %8, %2, offsets = [0, 0, 0], sizes = [64, 56, 56], strides = [1, 1, 1] : tensor<64x56x56xi8> -> !flow.dispatch.tensor<writeonly:tensor<64x56x56xi8>>
        return
      }
    }
  }
}

module_main_graph_dispatch_47.mlir

hal.executable public @main_graph_dispatch_47 {
  hal.executable.variant public @embedded_elf_x86_64 target(<"llvm-cpu", "embedded-elf-x86_64", {cpu = "generic", cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 16 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf"}>) {
    hal.executable.export public @main_graph_dispatch_47_elementwise_64x56x56_f32 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>], flags = Indirect>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>]} {
    ^bb0(%arg0: !hal.device):
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
      hal.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @main_graph_dispatch_47_elementwise_64x56x56_f32() {
        %cst = arith.constant 0.000000e+00 : f32
        %cst_0 = arith.constant -1.280000e+02 : f32
        %cst_1 = arith.constant 1.270000e+02 : f32
        %cst_2 = arith.constant 1.562500e-02 : f32
        %c2007040 = arith.constant 2007040 : index
        %c802816 = arith.constant 802816 : index
        %c1003520 = arith.constant 1003520 : index
        %0 = hal.interface.binding.subspan layout(<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>], flags = Indirect>]>) set(0) binding(0) alignment(64) offset(%c2007040) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1x80x56x56xf32>>
        %1 = hal.interface.binding.subspan layout(<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>], flags = Indirect>]>) set(0) binding(0) alignment(64) offset(%c802816) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<200704xi8>>
        %2 = hal.interface.binding.subspan layout(<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>], flags = Indirect>]>) set(0) binding(1) alignment(64) offset(%c1003520) : !flow.dispatch.tensor<writeonly:tensor<64x56x56xf32>>
        %3 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [200704], strides = [1] : !flow.dispatch.tensor<readonly:tensor<200704xi8>> -> tensor<200704xi8>
        %4 = tensor.empty() : tensor<64x56x56xf32>
        %5 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 64, 56, 56], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x80x56x56xf32>> -> tensor<64x56x56xf32>
        %6 = tensor.empty() : tensor<200704xf32>
        %7 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%3 : tensor<200704xi8>) outs(%6 : tensor<200704xf32>) {
        ^bb0(%in: i8, %out: f32):
          %9 = arith.extsi %in : i8 to i32
          %10 = arith.sitofp %9 : i32 to f32
          %11 = arith.mulf %10, %cst_2 : f32
          linalg.yield %11 : f32
        } -> tensor<200704xf32>
        %expanded = tensor.expand_shape %7 [[0, 1, 2]] output_shape [64, 56, 56] : tensor<200704xf32> into tensor<64x56x56xf32>
        %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded, %5 : tensor<64x56x56xf32>, tensor<64x56x56xf32>) outs(%4 : tensor<64x56x56xf32>) {
        ^bb0(%in: f32, %in_3: f32, %out: f32):
          %9 = arith.divf %in_3, %cst_2 : f32
          %10 = math.roundeven %9 : f32
          %11 = arith.addf %10, %cst : f32
          %12 = arith.maximumf %11, %cst_0 : f32
          %13 = arith.minimumf %12, %cst_1 : f32
          %14 = arith.fptosi %13 : f32 to i8
          %15 = arith.extsi %14 : i8 to i32
          %16 = arith.sitofp %15 : i32 to f32
          %17 = arith.mulf %16, %cst_2 : f32
          %18 = arith.addf %in, %17 : f32
          %19 = arith.divf %18, %cst_2 : f32
          %20 = math.roundeven %19 : f32
          %21 = arith.addf %20, %cst : f32
          %22 = arith.maximumf %21, %cst_0 : f32
          %23 = arith.minimumf %22, %cst_1 : f32
          %24 = arith.fptosi %23 : f32 to i8
          %25 = arith.extsi %24 : i8 to i32
          %26 = arith.sitofp %25 : i32 to f32
          %27 = arith.mulf %26, %cst_2 : f32
          linalg.yield %27 : f32
        } -> tensor<64x56x56xf32>
        flow.dispatch.tensor.store %8, %2, offsets = [0, 0, 0], sizes = [64, 56, 56], strides = [1, 1, 1] : tensor<64x56x56xf32> -> !flow.dispatch.tensor<writeonly:tensor<64x56x56xf32>>
        return
      }
    }
  }
}
pashu123 commented 1 month ago

> @pashu123 Here is the command: iree-compile --iree-input-demote-i64-to-i32 --iree-hal-target-backends=llvm-cpu dpn68_vaiq.default.onnx.linalg.mlir > dpn68_vaiq.default.vmfb --iree-hal-dump-executable-sources-to=./dispatch […]

Looking into the kernel, the two generics can be fused if the tensor.expand_shape is propagated upward or downward, so that the first operand of the second generic is either expanded to 3D or collapsed to 1D. @hanhanW, do you have any suggestions? My take is to enable linalg elementwise fusion on dispatches of this kind.
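For reference, a hedged sketch of what "propagate the expand_shape, then fuse elementwise ops" could look like with upstream MLIR patterns (this is not IREE's actual flow-level implementation; the wrapper function is hypothetical):

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Bubble tensor.expand_shape ops through elementwise producers, then fuse the
// resulting adjacent elementwise linalg.generic ops, as suggested above.
static LogicalResult fuseElementwiseAcrossReshapes(func::FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  linalg::ControlFusionFn control = [](OpOperand *) { return true; };
  linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, control);
  linalg::populateElementwiseOpsFusionPatterns(patterns, control);
  return applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}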

pashu123 commented 1 month ago

Looking into the failure https://gist.github.com/AmosLewis/35ed28904fd6e82de0c66546b18579df#file-dpn68_vaiq_iree_failed-log, the problem is with fusing the producer.

pashu123 commented 1 month ago

I created a repro IR at https://gist.github.com/pashu123/d52f974975f0ebcfa6b131d076660e70, and it compiles successfully. After bubbling up the tensor.expand_shape op, elementwise fusion took place.

hanhanW commented 1 month ago

@MaheshRavishankar @IanWood1 why do we have a tensor.expand_shape in between? I thought the reshape ops become flow.reshape ops and we don't fuse them into dispatches?

If it is expected, do we apply some cleanups (like what Prashant mentioned) at the flow level? That would benefit all the backends.

IanWood1 commented 1 month ago

We shouldn't have them in between; they don't get cloned. It's possible that this came from CollapseDimensionsPass.

I'm going to take a look at that.

IanWood1 commented 1 month ago

There are a number of reasons why the backend might emit a large vector sizes failure. My understanding is that it is mainly due to codegen failing to fuse ops and therefore needing extra memory to store transient data. Here are some cases I'm thinking of:

  1. op -> reshape -> op
  2. dequant op -> rank reducing extract slice -> op
  3. op -> insert slice -> op
  4. ... other harder to diagnose problems

@hanhanW do we want to try to emit a more descriptive error? Or at least have a more explicit check for ops we know can't be fused? But maybe providing a descriptive error here is more difficult than I understand.

hanhanW commented 1 month ago

My understanding is that all the compute ops in the SSA chain should implement TilingInterface; otherwise, there is not much codegen can do. So (1), (2), and (3) are problematic codegen inputs to me. This does not only happen on the CPU backend; it happens on all the other backends as well. Basically, you won't be able to distribute the workload without modifying the graph. If the conclusion is that we always want to update the graph, then dispatch creation should not generate such dispatches for the backends. Thus, I think we could add a VerifyDispatchRegionLegality pass to detect the case at the end of the DispatchRegionCreation phase.

(I know that we have an option that fuses everything into a single dispatch, but it is not the default behavior. The CPU backend can handle such a case, but in a nonsensical way that is very slow.)
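A minimal sketch of what such a verifier could flag, under the assumption that the problematic pattern is a reshape sandwiched between tileable compute ops inside a dispatch (hypothetical pass; VerifyDispatchRegionLegality does not exist yet, and the function below is illustrative):

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Report tensor.expand_shape ops that sit between two TilingInterface compute
// ops inside a dispatch region; such dispatches cannot be distributed without
// rewriting the graph.
static LogicalResult verifyDispatchRegionLegality(Region &dispatchBody) {
  WalkResult result = dispatchBody.walk([](tensor::ExpandShapeOp reshape) {
    Operation *producer = reshape.getSrc().getDefiningOp();
    bool feedsCompute = llvm::any_of(reshape->getUsers(), [](Operation *user) {
      return isa<TilingInterface>(user);
    });
    if (producer && isa<TilingInterface>(producer) && feedsCompute) {
      reshape.emitError("reshape between tileable compute ops in a dispatch; "
                        "backends cannot tile and distribute across it");
      return WalkResult::interrupt();
    }
    return WalkResult::advance();
  });
  return failure(result.wasInterrupted());
}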

MaheshRavishankar commented 1 month ago

It's hard to write such verifiers because they depend on the implementation status of codegen and are not tied to any "real" constraints (like large vectors are bad, or large stack allocations are bad).

IIUC this is a bug, and we will need to investigate the cause every time we hit this error (it's basically a catch-all for "something went off the rails"). Not sure we can really do better than that.