iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Unknown upper bound allocation due to LinalgStrategyPeelPass #11196

Open pzread opened 1 year ago

pzread commented 1 year ago

We found that LinalgStrategyPeelPass optimizes away some affine.min ops used by tensor.empty, which prevents a later check from computing an upper bound for the memory allocation.

This was originally found while debugging a failing test, iree_tf_tests/math/llvmcpu__dynamic_dim_softmax.run, with a WIP change (#10770).

Here is a crafted example to reproduce the issue:

func.func @peel_dyn(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  // %alloc_buf is left uninitialized so it is not optimized away.
  %alloc_buf = tensor.empty(%dim0) : tensor<?xf32>
  %out = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
  %res = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } ins(%arg0, %arg1, %alloc_buf : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>) outs(%out : tensor<?x?xf32>) {
  ^bb0(%a: f32, %b: f32, %c: f32, %z: f32):
    %w = arith.maxf %a, %b : f32
    %x = arith.minf %w, %c : f32
    linalg.yield %x : f32
  } -> tensor<?x?xf32>
  return %res : tensor<?x?xf32>
}

Compiling the example with iree-compile --iree-hal-target-backends=llvm-cpu produces the following error:

error: 'memref.alloca' op expected no stack allocations without upper bound shapes

What Happened

Before LinalgStrategyPeelPass, the example is tiled into:

// -----// IR Dump Before LinalgStrategyPeelPass (iree-linalg-strategy-peel-pass) //----- //
func.func @peel_dyn_dispatch_0_generic_DxD() {
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %c0 = arith.constant 0 : index
  %0 = hal.interface.constant.load[0] : i32
  %1 = hal.interface.constant.load[1] : i32
  %2 = hal.interface.constant.load[2] : i32
  %3 = hal.interface.constant.load[3] : i32
  %4 = arith.index_castui %0 : i32 to index
  %5 = arith.index_castui %1 : i32 to index
  %6 = arith.index_castui %2 : i32 to index
  %7 = arith.index_castui %3 : i32 to index
  %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%6, %7}
  %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%4, %5}
  %10 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:tensor<?x?xf32>>{%6, %7}
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %workgroup_count_x = hal.interface.workgroup.count[0] : index
  %workgroup_id_y = hal.interface.workgroup.id[1] : index
  %workgroup_count_y = hal.interface.workgroup.count[1] : index
  %11 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_y]
  %12 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_y]
  %13 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
  %14 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_x]
  scf.for %arg0 = %11 to %6 step %12 {
    %15 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 64)>(%arg0)[%6]
    scf.for %arg1 = %13 to %7 step %14 {
      %16 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 64)>(%arg1)[%7]
      %17 = flow.dispatch.tensor.load %10, offsets = [%arg0, %arg1], sizes = [%15, %16], strides = [1, 1] : !flow.dispatch.tensor<writeonly:tensor<?x?xf32>>{%6, %7} -> tensor<?x?xf32>
      %18 = flow.dispatch.tensor.load %8, offsets = [%arg0, %arg1], sizes = [%15, %16], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%6, %7} -> tensor<?x?xf32>
      %19 = flow.dispatch.tensor.load %9, offsets = [%arg0, %arg1], sizes = [%15, %16], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%4, %5} -> tensor<?x?xf32>
      %dim = tensor.dim %18, %c0 : tensor<?x?xf32>
      %dim_0 = tensor.dim %18, %c1 : tensor<?x?xf32>
      %20 = scf.for %arg2 = %c0 to %dim step %c4 iter_args(%arg3 = %17) -> (tensor<?x?xf32>) {
        %21 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%arg2)[%dim]
        %22 = tensor.empty(%21) : tensor<?xf32>
        %23 = scf.for %arg4 = %c0 to %dim_0 step %c4 iter_args(%arg5 = %arg3) -> (tensor<?x?xf32>) {
          %24 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%arg4)[%dim_0]
          %extracted_slice = tensor.extract_slice %18[%arg2, %arg4] [%21, %24] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
          %extracted_slice_1 = tensor.extract_slice %19[%arg2, %arg4] [%21, %24] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
          %extracted_slice_2 = tensor.extract_slice %arg5[%arg2, %arg4] [%21, %24] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
          %25 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice, %extracted_slice_1, %22 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>) outs(%extracted_slice_2 : tensor<?x?xf32>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64], [4, 4], [0, 0]]>} {
          ^bb0(%in: f32, %in_3: f32, %in_4: f32, %out: f32):
            %26 = arith.maxf %in, %in_3 : f32
            %27 = arith.minf %26, %in_4 : f32
            linalg.yield %27 : f32
          } -> tensor<?x?xf32>
          %inserted_slice = tensor.insert_slice %25 into %arg5[%arg2, %arg4] [%21, %24] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
          scf.yield %inserted_slice : tensor<?x?xf32>
        }
        scf.yield %23 : tensor<?x?xf32>
      }
      flow.dispatch.tensor.store %20, %10, offsets = [%arg0, %arg1], sizes = [%15, %16], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x?xf32>>{%6, %7}
    }
  }
  return
}

After LinalgStrategyPeelPass, the output is:

// -----// IR Dump After LinalgStrategyPeelPass (iree-linalg-strategy-peel-pass) //----- //
func.func @peel_dyn_dispatch_0_generic_DxD() {
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %c0 = arith.constant 0 : index
  %0 = hal.interface.constant.load[0] : i32
  %1 = hal.interface.constant.load[1] : i32
  %2 = hal.interface.constant.load[2] : i32
  %3 = hal.interface.constant.load[3] : i32
  %4 = arith.index_castui %0 : i32 to index
  %5 = arith.index_castui %1 : i32 to index
  %6 = arith.index_castui %2 : i32 to index
  %7 = arith.index_castui %3 : i32 to index
  %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%6, %7}
  %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%4, %5}
  %10 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:tensor<?x?xf32>>{%6, %7}
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %workgroup_count_x = hal.interface.workgroup.count[0] : index
  %workgroup_id_y = hal.interface.workgroup.id[1] : index
  %workgroup_count_y = hal.interface.workgroup.count[1] : index
  %11 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_y]
  %12 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_y]
  %13 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
  %14 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_x]
  scf.for %arg0 = %11 to %6 step %12 {
    %15 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 64)>(%arg0)[%6]
    scf.for %arg1 = %13 to %7 step %14 {
      %16 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 64)>(%arg1)[%7]
      %17 = flow.dispatch.tensor.load %10, offsets = [%arg0, %arg1], sizes = [%15, %16], strides = [1, 1] : !flow.dispatch.tensor<writeonly:tensor<?x?xf32>>{%6, %7} -> tensor<?x?xf32>
      %18 = flow.dispatch.tensor.load %8, offsets = [%arg0, %arg1], sizes = [%15, %16], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%6, %7} -> tensor<?x?xf32>
      %19 = flow.dispatch.tensor.load %9, offsets = [%arg0, %arg1], sizes = [%15, %16], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%4, %5} -> tensor<?x?xf32>
      %dim = tensor.dim %18, %c0 : tensor<?x?xf32>
      %dim_0 = tensor.dim %18, %c1 : tensor<?x?xf32>
      %20 = affine.apply affine_map<()[s0, s1, s2] -> (s1 - (s1 - s0) mod s2)>()[%c0, %dim, %c4]
      %21 = scf.for %arg2 = %c0 to %20 step %c4 iter_args(%arg3 = %17) -> (tensor<?x?xf32>) {
        %23 = tensor.empty(%c4) : tensor<?xf32>
        %24 = affine.apply affine_map<()[s0, s1, s2] -> (s1 - (s1 - s0) mod s2)>()[%c0, %dim_0, %c4]
        %25 = scf.for %arg4 = %c0 to %24 step %c4 iter_args(%arg5 = %arg3) -> (tensor<?x?xf32>) {
          %extracted_slice = tensor.extract_slice %18[%arg2, %arg4] [%c4, %c4] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
          %extracted_slice_1 = tensor.extract_slice %19[%arg2, %arg4] [%c4, %c4] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
          %extracted_slice_2 = tensor.extract_slice %arg5[%arg2, %arg4] [%c4, %c4] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
          %27 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice, %extracted_slice_1, %23 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>) outs(%extracted_slice_2 : tensor<?x?xf32>) attrs =  {__internal_linalg_transform__ = "1", lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64], [4, 4], [0, 0]]>} {
          ^bb0(%in: f32, %in_3: f32, %in_4: f32, %out: f32):
            %28 = arith.maxf %in, %in_3 : f32
            %29 = arith.minf %28, %in_4 : f32
            linalg.yield %29 : f32
          } -> tensor<?x?xf32>
          %inserted_slice = tensor.insert_slice %27 into %arg5[%arg2, %arg4] [%c4, %c4] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
          scf.yield %inserted_slice : tensor<?x?xf32>
        }
        %26 = scf.for %arg4 = %24 to %dim_0 step %c4 iter_args(%arg5 = %25) -> (tensor<?x?xf32>) {
          %27 = affine.apply affine_map<(d0, d1) -> (-d0 + d1)>(%arg4, %dim_0)
          %extracted_slice = tensor.extract_slice %18[%arg2, %arg4] [%c4, %27] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
          %extracted_slice_1 = tensor.extract_slice %19[%arg2, %arg4] [%c4, %27] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
          %extracted_slice_2 = tensor.extract_slice %arg5[%arg2, %arg4] [%c4, %27] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
          %28 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice, %extracted_slice_1, %23 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>) outs(%extracted_slice_2 : tensor<?x?xf32>) attrs =  {__internal_linalg_transform__ = "1", lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64], [4, 4], [0, 0]]>} {
          ^bb0(%in: f32, %in_3: f32, %in_4: f32, %out: f32):
            %29 = arith.maxf %in, %in_3 : f32
            %30 = arith.minf %29, %in_4 : f32
            linalg.yield %30 : f32
          } -> tensor<?x?xf32>
          %inserted_slice = tensor.insert_slice %28 into %arg5[%arg2, %arg4] [%c4, %27] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
          scf.yield %inserted_slice : tensor<?x?xf32>
        }
        scf.yield %26 : tensor<?x?xf32>
      }
      %22 = scf.for %arg2 = %20 to %dim step %c4 iter_args(%arg3 = %21) -> (tensor<?x?xf32>) {
        %23 = affine.apply affine_map<(d0, d1) -> (-d0 + d1)>(%arg2, %dim)
        %24 = tensor.empty(%23) : tensor<?xf32>
        %25 = scf.for %arg4 = %c0 to %dim_0 step %c4 iter_args(%arg5 = %arg3) -> (tensor<?x?xf32>) {
          %26 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%arg4)[%dim_0]
          %extracted_slice = tensor.extract_slice %18[%arg2, %arg4] [%23, %26] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
          %extracted_slice_1 = tensor.extract_slice %19[%arg2, %arg4] [%23, %26] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
          %extracted_slice_2 = tensor.extract_slice %arg5[%arg2, %arg4] [%23, %26] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
          %27 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice, %extracted_slice_1, %24 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>) outs(%extracted_slice_2 : tensor<?x?xf32>) attrs =  {__internal_linalg_transform__ = "1", lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64], [4, 4], [0, 0]]>} {
          ^bb0(%in: f32, %in_3: f32, %in_4: f32, %out: f32):
            %28 = arith.maxf %in, %in_3 : f32
            %29 = arith.minf %28, %in_4 : f32
            linalg.yield %29 : f32
          } -> tensor<?x?xf32>
          %inserted_slice = tensor.insert_slice %27 into %arg5[%arg2, %arg4] [%23, %26] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
          scf.yield %inserted_slice : tensor<?x?xf32>
        }
        scf.yield %25 : tensor<?x?xf32>
      }
      flow.dispatch.tensor.store %22, %10, offsets = [%arg0, %arg1], sizes = [%15, %16], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x?xf32>>{%6, %7}
    }
  }
  return
}

Note that the size of the tensor allocation in the second (peeled remainder) loop has become a plain affine.apply due to the optimization [1]:

%23 = affine.apply affine_map<(d0, d1) -> (-d0 + d1)>(%arg2, %dim)
%24 = tensor.empty(%23) : tensor<?xf32>

while originally it was:

%21 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%arg2)[%dim]
%22 = tensor.empty(%21) : tensor<?xf32>

This causes the later LLVMCPUCheckIRBeforeLLVMConversion pass to fail to derive an upper bound for the tensor allocation [2].
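
To make the failure concrete, here is a minimal Python sketch of the kind of bound reasoning involved; it is illustrative only and not the actual linalg::getConstantUpperBoundForIndex implementation. An affine.min with a constant operand (e.g. min(-d0 + s0, 4)) is bounded by that constant, while a plain affine.apply over runtime values (e.g. -d0 + %dim) admits no constant bound, so the check rejects the allocation.

# Illustrative sketch only; not the real MLIR/IREE implementation.
def constant_upper_bound(expr):
    """Return a constant upper bound for a size expression, or None if unbounded."""
    if expr["kind"] == "constant":
        return expr["value"]
    if expr["kind"] == "min":
        # affine.min is bounded by any constant operand: min(-d0 + s0, 4) <= 4.
        bounds = [constant_upper_bound(e) for e in expr["operands"]]
        known = [b for b in bounds if b is not None]
        return min(known) if known else None
    # A plain affine.apply over runtime values (e.g. -d0 + %dim) has no constant bound.
    return None

before_peeling = {"kind": "min",
                  "operands": [{"kind": "apply"}, {"kind": "constant", "value": 4}]}
after_peeling = {"kind": "apply"}
print(constant_upper_bound(before_peeling))  # 4 -> tensor.empty is provably small
print(constant_upper_bound(after_peeling))   # None -> "unknown upper bound" error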

pzread commented 1 year ago

Not sure about the solution here. I think it is hard for linalg::getConstantUpperBoundForIndex in LLVMCPUCheckIRBeforeLLVMConversion to derive the upper bound of an affine.apply.

If it's not easy to solve, maybe we can disable the check in LLVMCPUCheckIRBeforeLLVMConversion as a temporary solution.

@hanhanW FYI

pzread commented 1 year ago

Reopening, since we decided to disable some tests for this issue.

dcaballe commented 1 year ago

This problem seems to get worse when we enable vector masking (https://github.com/iree-org/iree/pull/11286). I'm now getting similar stack allocation errors when building iree-test-deps:

/usr/local/google/home/diegocaballero/iree2/tests/e2e/models/mobilenetv3_fake_weights.mlir:977:12: error: 'builtin.module' op expected total size of stack allocation is not greater than 32768 bytes, but got 74112 bytes
    %731 = mhlo.maximum %730, %262 : tensor<1x1x1x144xf32>
/usr/local/google/home/diegocaballero/iree2/build/debug/tests/e2e/matmul/e2e_matmul_direct_f32_large_split_k_llvm-cpu_local-task.mlir:12:13: error: 'builtin.module' op expected total size of stack allocation is not greater than 32768 bytes, but got 53504 bytes

From a compiler perspective, I wonder why we are not using heap allocation. It wasn't mentioned in our discussions, so I guess there are some limitations that prevent us from doing so, but I'm pretty surprised that we can get things moving with only stack allocations. Does anybody know?

dcaballe commented 1 year ago

There seem to be even more stack allocation issues when building the benchmark suite with masking. We need to think about what to do because this is now a blocker. Any other thoughts @pzread, @MaheshRavishankar, @hanhanW?

pzread commented 1 year ago

Do we know the source of those stack allocations? Are they unbounded allocations, or fixed-size but too large? If they are fixed-size but too large, that's probably something we want to avoid, and we need to look into how to eliminate them.

I'm quite curious why vector masking generates more stack allocations.

hanhanW commented 1 year ago

I guess there are some limitations that prevent us from doing so, but I'm pretty surprised that we can get things moving with only stack allocations.

I think IREE does not want heap allocation in codegen because not every device target has a big heap. @benvanik knows more about it. If we really need a big heap buffer, we can maybe propagate it to the HAL and pass the buffer as one of the inputs/outputs. @benvanik mentioned that it's doable and he might already have built some pieces that we can reuse.

One question is: why do we need a big buffer here? Looking into the input IR, and assuming that there is a producer (which is a matmul) in the same dispatch:

matmul ... -> tensor<1x1x144xf32>
%731 = mhlo.maximum %730, %262 : tensor<1x1x1x144xf32>

Do we really need an extra buffer to codegen the dispatch? If so, why is it larger than 32 KB?

(I thought integrating vector masking would be an NFC step and that we would then extend its usage to get better performance, but it seems not?)

MaheshRavishankar commented 1 year ago

Could you post the IR so we can diagnose the problem better? But broadly, it's what Hanhan says: we only allow limited stack allocations and no additional heap allocations within a dispatch. So for the dispatches we create, we should never need stack or heap allocations (or at least not beyond the bound). If you vectorize, you mostly shouldn't need any heap allocation, so I am not sure why masking requires more stack/heap allocation.

benvanik commented 1 year ago

+1 hanhan/mahesh! There are ways to get transient buffers, but I've yet to see a case where they're actually needed in any of the kinds of things we codegen (today), so when issues are encountered it's a good indicator that something in the pipeline is failing.

Hopefully useful background, since it's been a year or two since this last came up:

The base assertion is that if you're doing a local workgroup's worth of work, you shouldn't need a global problem-sized block of allocated memory. If, however, you're just on the border of what's reasonable in the local scope (maybe this issue?), that's something that will need to factor into distribution across all backends, as they all have similar limits (shared memory sizes on GPUs, etc.).

An example of one way we could handle this is to set up the codegen translation pipeline to try to codegen with the given distribution, fail when too much local memory is used, and retry with finer distributions until it succeeds. We will need this kind of behavior in some form anyway: when we are targeting a broad spectrum of Vulkan/Metal/etc. devices, we're likely to need to produce those variants: https://vulkan.gpuinfo.org/displaydevicelimit.php?name=maxComputeSharedMemorySize&platform=all (today we just hardcode the attributes, but that's not something we can ship). That only works if we can get to a reasonable point where scaling the distribution scales the local memory use, though, and the best case would be if it's parametric. Parametric representations of this stuff ("given this workload and distribution, how much memory do you need?" -> "given this workload and available memory, how much distribution do you want?") would let us tune (certain) aspects of distribution at runtime based on the available hardware.
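
For illustration, here is a minimal sketch of that retry idea; the cost model, tile sizes, and the 32 KB budget are placeholder assumptions rather than IREE's actual heuristics.

# Hypothetical sketch of "codegen, check local memory, retry with a finer distribution".
SHARED_MEMORY_LIMIT = 32 * 1024  # placeholder budget (cf. the 32768-byte llvm-cpu limit)

def estimate_local_memory(tile_m, tile_n, bytes_per_elem=4, num_buffers=2):
    # Toy cost model: a couple of tile-sized temporaries.
    return tile_m * tile_n * bytes_per_elem * num_buffers

def pick_distribution(tile_m=128, tile_n=128, limit=SHARED_MEMORY_LIMIT):
    while tile_m >= 1 and tile_n >= 1:
        if estimate_local_memory(tile_m, tile_n) <= limit:
            return tile_m, tile_n  # this distribution fits the local memory budget
        # Too much local memory: retry with a finer distribution (smaller tiles).
        tile_m, tile_n = max(tile_m // 2, 1), max(tile_n // 2, 1)
    raise RuntimeError("no distribution fits the local memory budget")

print(pick_distribution())  # (128, 128) is too large under this toy model -> (64, 64)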

But the best thing to keep in mind, both in the short and long term, is to focus on minimizing workgroup-local memory sizes as much as practically/algorithmically possible. It would be an interesting metric to track over time (alongside dispatch binary size and such), as it's usually only possible to see the derivative effects in vendor tools (decreased utilization because of higher shared resource contention, etc.). Stare at the tools long enough and you get a healthy fear of shared memory when it's not absolutely required. That is to say, if someone introduces a change that adds even just a few KB of additional shared memory consumption (a few wide vectors), then occupancy will be halved, so IMO any problems we have around additional large shared memory allocations being produced by codegen should be P1 performance issues. This is really critical on GPU but also important on CPU: our "shared memory" is the CPU data cache, and if we spill out of that we risk introducing cache contention and thus lowering occupancy.
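
As a toy arithmetic example of that occupancy effect (the 48 KB per-SM figure is an assumption for illustration; real limits vary by device, as the Vulkan table linked above shows):

SHARED_MEM_PER_SM = 48 * 1024  # hypothetical per-SM shared memory

def resident_workgroups(shared_mem_per_workgroup):
    # Occupancy is capped by how many workgroups' shared memory fits per SM.
    return SHARED_MEM_PER_SM // shared_mem_per_workgroup

print(resident_workgroups(24 * 1024))         # 2 workgroups resident
print(resident_workgroups(24 * 1024 + 4096))  # a few extra KB -> 1 resident, occupancy halved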

dcaballe commented 1 year ago

Thank you all for the detailed feedback, and thank you, Che Yu, for giving me some hints for the investigation! I fixed the problem. We generated a pad + matmul, and the pad operation wasn't vectorized with my code whereas the matmul was. That created a temporary buffer between the two ops larger than the current 32 KB upper limit.

I agree with the approach and principles described above. However, from a production perspective, I wonder if we are prioritizing performance over support/coverage. We may want the compiler to generate "slower" code that can run instead of bailing out of compilation. Is the use of shared memory always mandatory? Are we able to fall back to global memory for cases where we exceed a specific shared memory threshold? Have we considered heap allocation or the use of smart allocators for some targets? (Sorry, lots of questions :)) I think these points may become more relevant as we invest more in datacenter support.

MaheshRavishankar commented 1 year ago

I agree with the approach and principles described above. However, from a production perspective, I wonder if we are prioritizing performance over support/coverage. We may want the compiler to generate "slower" code that can run instead of bailing out of compilation. Is the use of shared memory always mandatory? Are we able to fall back to global memory for cases where we exceed a specific shared memory threshold? Have we considered heap allocation or the use of smart allocators for some targets? (Sorry, lots of questions :)) I think these points may become more relevant as we invest more in datacenter support.

All of this should actually be factored into tile size selection, using this limit when deciding which tile sizes / configurations to choose. Ben can probably explain the tradeoffs for the other questions.

hanhanW commented 1 year ago

Coming from the other issue. Here is the IR before bufferization: https://gist.github.com/hanhanW/3dfd7d87a456f28319c6569e2d498cfd/raw

%18 = flow.dispatch.tensor.load %9, offsets = [%arg0], sizes = [%17], strides = [1] : !flow.dispatch.tensor<writeonly:tensor<?xf32>>{%6} -> tensor<?xf32>
%21 = scf.for %arg1 = %c0 to %20 step %c4 iter_args(%arg2 = %18) -> (tensor<?xf32>) {
  ...
}
%22 = scf.for %arg1 = %20 to %17 step %c4 iter_args(%arg2 = %21) -> (tensor<?xf32>) {
  ...
}
flow.dispatch.tensor.store %22, %9, offsets = [%arg0], sizes = [%17], strides = [1] : tensor<?xf32> -> !flow.dispatch.tensor<writeonly:tensor<?xf32>>{%6}

@matthias-springer I'm wondering if we can teach bufferization to not allocate a stack buffer for this type of case? It looks like all the operations can reuse the destination buffer?

matthias-springer commented 1 year ago

@hanhanW These two changes should remove the alloc + copy: https://reviews.llvm.org/D140007 https://reviews.llvm.org/D140008

matthias-springer commented 1 year ago

@hanhanW I prepared #11550, which cherry-picks those commits and a few others that I need (for a different purpose).

hanhanW commented 1 year ago

Thanks a lot for the quick fix!!

dcaballe commented 1 year ago

Awesome! Should we close this issue?

matthias-springer commented 1 year ago

There are still two test cases that were not fixed by #11550:

integrations/tensorflow/test/iree_tf_tests/math/llvmcpu__dynamic_dim_log_softmax.run:# TODO(#11196): Re-enable once the issue is resolved.
integrations/tensorflow/test/iree_tf_tests/math/llvmcpu__dynamic_dim_softmax.run:# TODO(#11196): Re-enable once the issue is resolved.

dcaballe commented 1 year ago

Hey @matthias-springer, what is the state of the last two cases that were not fixed by #11550? Is there anything pending here?