iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

EliminateEmptyTensorsPass fails to eliminate tensor.empty when the tensor access indexing math is placed after the empty instead of before #19025

Open · nirvedhmeshram opened this issue 1 week ago

nirvedhmeshram commented 1 week ago

In this working example, we have:

  %6 = scf.forall (%arg0, %arg1) in (16, 80) shared_outs(%arg2 = %3) -> (tensor<128x16x640x16xf32>) {
    %7 = affine.apply affine_map<(d0) -> (d0 * 8)>(%arg1)
    %8 = affine.apply affine_map<(d0) -> (d0 * 8)>(%arg0)
    %9 = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<128x32xf16>
    %10 = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<128x32xf16>
    %11 = tensor.empty() : tensor<8x16x8x16xf32>
    %12 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %11) -> (tensor<8x16x8x16xf32>) {
      ...
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %15 into %arg5[%14, 0, %13, 0] [4, 16, 4, 16] [1, 1, 1, 1] : tensor<4x16x4x16xf32> into tensor<8x16x8x16xf32>
      }
    } {mapping = [#gpu.warp<linear_dim_1>, #gpu.warp<linear_dim_0>]}
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %12 into %arg2[%8, 0, %7, 0] [8, 16, 8, 16] [1, 1, 1, 1] : tensor<8x16x8x16xf32> into tensor<128x16x640x16xf32>
    }
  } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}

This tensor.empty (%11) is correctly eliminated and replaced with:

  %extracted_slice = tensor.extract_slice %arg2[%8, 0, %7, 0] [8, 16, 8, 16] [1, 1, 1, 1] : tensor<128x16x640x16xf32> to tensor<8x16x8x16xf32>

However, consider this very similar case, where the only difference is that the affine calculations are done after the nested loop:

  %6 = scf.forall (%arg0, %arg1) in (16, 80) shared_outs(%arg2 = %3) -> (tensor<128x16x640x16xf32>) {
    %7 = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<128x32xf16>
    %8 = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<128x32xf16>
    %9 = tensor.empty() : tensor<8x16x8x16xf32>
    %10 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %9) -> (tensor<8x16x8x16xf32>) {
      ...
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %15 into %arg5[%14, 0, %13, 0] [4, 16, 4, 16] [1, 1, 1, 1] : tensor<4x16x4x16xf32> into tensor<8x16x8x16xf32>
      }
    } {mapping = [#gpu.warp<linear_dim_1>, #gpu.warp<linear_dim_0>]}
    %11 = affine.apply affine_map<(d0) -> (d0 * 8)>(%arg0)
    %12 = affine.apply affine_map<(d0) -> (d0 * 8)>(%arg1)
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %10 into %arg2[%11, 0, %12, 0] [8, 16, 8, 16] [1, 1, 1, 1] : tensor<8x16x8x16xf32> into tensor<128x16x640x16xf32>
    }
  } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}

Here, the tensor.empty (%9) survives the pass.

Note that the affine math has no uses other than the parallel_insert_slice.

nirvedhmeshram commented 1 week ago

@MaheshRavishankar, just an FYI: @Max191 and I spent time today looking at IR whose problem ultimately turned out to be this issue. We can always be cautious and place the affine math toward the start of the loop, but this seems like an upstream bug we should at least be aware of.

MaheshRavishankar commented 1 week ago

This is very strange. The placement of operations shouldn't matter. This definitely seems like an upstream bug.

Max191 commented 1 week ago

I think I might actually know what is going wrong. In the second case, no insertion point can be found for creating an extract_slice of %arg2. The extract_slice needs the slice offsets (%11 and %12), but those are defined below the nested loop. Since the offsets are not in scope before the loop (where the tensor.empty() needs to be replaced), the pass cannot create the extract_slice, so elimination fails.

@nirvedhmeshram I think the solution is to set the insertion point before the loop when computing the delinearized offsets for the parallel_insert_slice op in the collapse_shape propagation pattern.
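The constraint described above can be illustrated with a toy Python model. This is not MLIR or IREE code: it assumes a single straight-line block (the outer scf.forall body, with the alloc_tensor ops omitted) and compares list positions instead of computing real dominance. The Op class, the op names, and find_insertion_point are hypothetical. A replacement extract_slice must come after every value it needs and no later than the nested loop that consumes the tensor.empty.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    """An operation in a single straight-line block (toy model, not MLIR)."""
    name: str
    defs: tuple = ()  # SSA values defined by this op
    uses: tuple = ()  # SSA values read by this op


def find_insertion_point(block, block_args, needed, first_user):
    """Return the earliest index i such that inserting a new op before block[i]
    places it after the definition of every value in `needed` and no later
    than `first_user` (the op consuming the tensor.empty).
    Return None if both constraints cannot be met."""
    position = {op.name: i for i, op in enumerate(block)}
    defined_at = {v: i for i, op in enumerate(block) for v in op.defs}
    earliest = 0
    for value in needed:
        if value in block_args:
            continue  # block arguments are in scope throughout the block
        earliest = max(earliest, defined_at[value] + 1)
    latest = position[first_user]
    return earliest if earliest <= latest else None


# Block arguments of the outer scf.forall body (induction vars + shared_out).
BLOCK_ARGS = {"%arg0", "%arg1", "%arg2"}

# Working case: offsets %7 / %8 are computed before the nested loop.
CASE1 = [
    Op("affine_x", defs=("%7",), uses=("%arg1",)),
    Op("affine_y", defs=("%8",), uses=("%arg0",)),
    Op("empty", defs=("%11",)),
    Op("inner_forall", defs=("%12",), uses=("%11",)),
]

# Failing case: offsets %11 / %12 are computed after the nested loop.
CASE2 = [
    Op("empty", defs=("%9",)),
    Op("inner_forall", defs=("%10",), uses=("%9",)),
    Op("affine_y", defs=("%11",), uses=("%arg0",)),
    Op("affine_x", defs=("%12",), uses=("%arg1",)),
]

# Suggested fix (modeled): emit the offset math before the nested loop.
CASE2_FIXED = [CASE2[0], CASE2[2], CASE2[3], CASE2[1]]

if __name__ == "__main__":
    print(find_insertion_point(CASE1, BLOCK_ARGS, ("%arg2", "%8", "%7"), "inner_forall"))  # 2
    print(find_insertion_point(CASE2, BLOCK_ARGS, ("%arg2", "%11", "%12"), "inner_forall"))  # None
    print(find_insertion_point(CASE2_FIXED, BLOCK_ARGS, ("%arg2", "%11", "%12"), "inner_forall"))  # 3
```

In the reordered block, the replacement can go directly before the nested loop, which mirrors the first (working) example. In the original second ordering, the earliest legal position lies after the loop, so no valid insertion point exists.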