iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/

[spirv] vector.shape_cast is not handled in ConvertToSPIRV #14191

Closed by pzread 2 months ago

pzread commented 1 year ago

The OptimizeVectorTransferPass used in the SPIR-V pipeline can generate vector.shape_cast ops during its optimizations. For example:

https://github.com/openxla/iree/blob/6c016cac1c94ddf72314ae85053142f4b9babf87/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp#L102-L104

However, ConvertToSPIRVPass calls populateVectorToSPIRVPatterns during lowering, and those patterns don't handle vector.shape_cast.

We ran into this issue while trying to drop unit dims on vector transfers in OptimizeVectorTransferPass (#13340), which creates vector.shape_cast ops to drop the unit dims on vectors.
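
For context, here is a rough sketch of what that unit-dim dropping amounts to, assuming it is wired up through the upstream populateVectorTransferDropUnitDimsPatterns helper (the helper function name below is made up for illustration, and the actual change in #13340 may differ in the details):

#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Rewrites vector.transfer_read/write ops whose memrefs have unit dims into
// transfers on rank-reduced memrefs; vector.shape_cast ops are inserted to
// cast the rank-reduced vectors back to the original vector types.
static LogicalResult dropUnitDimsOnTransfers(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateVectorTransferDropUnitDimsPatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}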

Here is an example after OptimizeVectorTransferPass with the unit-dim dropping from #13340; the vector.shape_cast ops are generated during the optimization:

func.func @main_dispatch_71_generic_2x256_i8xi32() {
  %c0_i32 = arith.constant 0 : i32
  %cst = arith.constant dense<[1196100044, 1139971180]> : tensor<2xi32>
  %c0 = arith.constant 0 : index
  %cst_0 = arith.constant dense<0> : vector<1xi32>
  %c0_i8 = arith.constant 0 : i8
  %cst_1 = arith.constant dense<-128> : vector<1xi32>
  %cst_2 = arith.constant dense<39> : vector<1xi8>
  %cst_3 = arith.constant dense<-1> : vector<1xi32>
  %cst_4 = arith.constant dense<127> : vector<1xi32>
  %cst_5 = arith.constant dense<-1.000000e+00> : vector<1xf32>
  %cst_6 = arith.constant dense<0.0125187514> : vector<1xf32>
  %c20224 = arith.constant 20224 : index
  %c64 = arith.constant 64 : index
  %cst_7 = arith.constant dense<[16267, -17079]> : tensor<2xi32>
  %0 = bufferization.to_memref %cst_7 : memref<2xi32>
  %1 = bufferization.to_memref %cst : memref<2xi32>
  %2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c20224) flags(ReadOnly) : memref<256x2xi8, strided<[2, 1], offset: 20224>, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %2, 64 : memref<256x2xi8, strided<[2, 1], offset: 20224>, #hal.descriptor_type<storage_buffer>>
  %3 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<2xi32, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %3, 64 : memref<2xi32, #hal.descriptor_type<storage_buffer>>
  %4 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c64) : memref<2xf32, strided<[1], offset: 16>, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %4, 64 : memref<2xf32, strided<[1], offset: 16>, #hal.descriptor_type<storage_buffer>>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %subview = memref.subview %4[%workgroup_id_x] [1] [1] : memref<2xf32, strided<[1], offset: 16>, #hal.descriptor_type<storage_buffer>> to memref<1xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  %subview_8 = memref.subview %0[%workgroup_id_x] [1] [1] : memref<2xi32> to memref<1xi32, strided<[1], offset: ?>>
  %subview_9 = memref.subview %3[%workgroup_id_x] [1] [1] : memref<2xi32, #hal.descriptor_type<storage_buffer>> to memref<1xi32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  %subview_10 = memref.subview %2[0, %workgroup_id_x] [256, 1] [1, 1] : memref<256x2xi8, strided<[2, 1], offset: 20224>, #hal.descriptor_type<storage_buffer>> to memref<256x1xi8, strided<[2, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  %5 = vector.transfer_read %subview_10[%c0, %c0], %c0_i8 {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (d0)>} : memref<256x1xi8, strided<[2, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<256xi8>
  %6 = arith.extsi %5 : vector<256xi8> to vector<256xi32>
  %7 = vector.broadcast %6 : vector<256xi32> to vector<1x256xi32>
  %8 = vector.multi_reduction <add>, %7, %cst_0 [1] : vector<1x256xi32> to vector<1xi32>
  %subview_11 = memref.subview %1[%workgroup_id_x] [1] [1] : memref<2xi32> to memref<1xi32, strided<[1], offset: ?>>
  %subview_12 = memref.subview %subview_8[0] [1] [1] : memref<1xi32, strided<[1], offset: ?>> to memref<i32, strided<[], offset: ?>>
  %9 = vector.transfer_read %subview_12[], %c0_i32 : memref<i32, strided<[], offset: ?>>, vector<i32>
  %10 = vector.shape_cast %9 : vector<i32> to vector<1xi32>
  %subview_13 = memref.subview %subview_9[0] [1] [1] : memref<1xi32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<i32, strided<[], offset: ?>, #hal.descriptor_type<storage_buffer>>
  %11 = vector.transfer_read %subview_13[], %c0_i32 : memref<i32, strided<[], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<i32>
  %12 = vector.shape_cast %11 : vector<i32> to vector<1xi32>
  %subview_14 = memref.subview %subview_11[0] [1] [1] : memref<1xi32, strided<[1], offset: ?>> to memref<i32, strided<[], offset: ?>>
  %13 = vector.transfer_read %subview_14[], %c0_i32 : memref<i32, strided<[], offset: ?>>, vector<i32>
  %14 = vector.shape_cast %13 : vector<i32> to vector<1xi32>
  %15 = arith.muli %8, %cst_1 : vector<1xi32>
  %16 = arith.subi %12, %15 : vector<1xi32>
  %17 = arith.addi %10, %16 : vector<1xi32>
  %18 = "tosa.apply_scale"(%17, %14, %cst_2) <{double_round = true}> : (vector<1xi32>, vector<1xi32>, vector<1xi8>) -> vector<1xi32>
  %19 = arith.addi %18, %cst_3 : vector<1xi32>
  %20 = arith.cmpi slt, %19, %cst_1 : vector<1xi32>
  %21 = arith.select %20, %cst_1, %19 : vector<1xi1>, vector<1xi32>
  %22 = arith.cmpi sgt, %19, %cst_4 : vector<1xi32>
  %23 = arith.select %22, %cst_4, %21 : vector<1xi1>, vector<1xi32>
  %24 = arith.trunci %23 : vector<1xi32> to vector<1xi8>
  %25 = arith.sitofp %24 : vector<1xi8> to vector<1xf32>
  %26 = arith.subf %25, %cst_5 : vector<1xf32>
  %27 = arith.mulf %26, %cst_6 : vector<1xf32>
  %subview_15 = memref.subview %subview[0] [1] [1] : memref<1xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<f32, strided<[], offset: ?>, #hal.descriptor_type<storage_buffer>>
  %28 = vector.shape_cast %27 : vector<1xf32> to vector<f32>
  vector.transfer_write %28, %subview_15[] : vector<f32>, memref<f32, strided<[], offset: ?>, #hal.descriptor_type<storage_buffer>>
  return
}

Reproduce

To reproduce, I dumped the IR containing vector.shape_cast just before ConvertToSPIRV: https://gist.github.com/pzread/48352affa3aa5255d285f81091e1ece8

iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-convert-to-spirv))))' sample.mlir

antiagainst commented 1 year ago

As discussed on Discord, we just need to add a vector-to-SPIR-V pattern to handle such size-1 shape casts; they are no-ops when translating to SPIR-V, since both the source and result types convert to scalar values. Assigning to @kuhar to fix it.
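
For illustration, here is a minimal sketch of what such a conversion pattern could look like (the class name is hypothetical, and this is not necessarily the exact pattern that ends up landing): when both sides of the cast are size-1 vectors, the SPIR-V type converter maps them to plain scalars, so the op can simply forward its already-converted operand.

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hypothetical pattern: fold a size-1 vector.shape_cast during vector-to-SPIR-V
// conversion. Both vector<T> and vector<1xT> lower to the scalar T, so the cast
// is a no-op at the SPIR-V level.
struct FoldTrivialShapeCast final
    : public OpConversionPattern<vector::ShapeCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ShapeCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only handle casts where both the source and the result hold a single
    // element, e.g. vector<i32> <-> vector<1xi32>.
    if (op.getSourceVectorType().getNumElements() != 1 ||
        op.getResultVectorType().getNumElements() != 1)
      return failure();
    // Bail out if the result type doesn't convert under the SPIR-V type
    // converter.
    if (!getTypeConverter()->convertType(op.getResultVectorType()))
      return failure();
    // Forward the already-converted (scalar) operand.
    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
};

Such a pattern could then be registered next to the upstream ones, e.g. with patterns.add<FoldTrivialShapeCast>(typeConverter, context) right after the populateVectorToSPIRVPatterns call in ConvertToSPIRVPass.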

pzread commented 1 year ago

I'm actually more than happy to try to fix this :) Assigning it to myself as well.

antiagainst commented 1 year ago

Cool, that's great! Feel free to ping @kuhar or me if you have questions then! :)

pzread commented 1 year ago

It looks like there is a size regression when combining dropping unit dims on vector transfers with converting trivial shape_casts to no-ops:

https://github.com/openxla/iree/pull/14220#issuecomment-1606270765

I'll investigate it further.

pzread commented 1 year ago

The problem is in VectorReductionToGPU: the patterns added by mlir::vector::populatePropagateWarpVectorDistributionPatterns can't handle vector.shape_cast, which results in bad warp distribution:

----- After VectorReduceToGPU -----
func.func @main_dispatch_84_generic_2x256_i8xi32() {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %cst_0 = arith.constant dense<0.0125187514> : vector<1xf32>
  %cst_1 = arith.constant dense<-1.000000e+00> : vector<1xf32>
  %cst_2 = arith.constant dense<127> : vector<1xi32>
  %cst_3 = arith.constant dense<-1> : vector<1xi32>
  %cst_4 = arith.constant dense<39> : vector<1xi8>
  %cst_5 = arith.constant dense<-128> : vector<1xi32>
  %cst_6 = arith.constant dense<[16267, -17079]> : tensor<2xi32>
  %c20224 = arith.constant 20224 : index
  %c0_i8 = arith.constant 0 : i8
  %cst_7 = arith.constant dense<[1196100044, 1139971180]> : tensor<2xi32>
  %c0_i32 = arith.constant 0 : i32
  %c64 = arith.constant 64 : index
  %0 = gpu.thread_id  x
  %1 = bufferization.to_memref %cst_6 : memref<2xi32>
  %2 = bufferization.to_memref %cst_7 : memref<2xi32>
  %3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c20224) flags(ReadOnly) : memref<256x2xi8, strided<[2, 1], offset: 20224>, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %3, 64 : memref<256x2xi8, strided<[2, 1], offset: 20224>, #hal.descriptor_type<storage_buffer>>
  %4 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<2xi32, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %4, 64 : memref<2xi32, #hal.descriptor_type<storage_buffer>>
  %5 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c64) : memref<2xf32, strided<[1], offset: 16>, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %5, 64 : memref<2xf32, strided<[1], offset: 16>, #hal.descriptor_type<storage_buffer>>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %6 = arith.cmpi eq, %0, %c0 : index
  %alloc = memref.alloc() : memref<f32, #gpu.address_space<workgroup>>
  scf.if %6 {
    %9 = vector.transfer_read %3[%c0, %workgroup_id_x], %c0_i8 {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (d0)>} : memref<256x2xi8, strided<[2, 1], offset: 20224>, #hal.descriptor_type<storage_buffer>>, vector<256xi8>
    %10 = arith.extsi %9 : vector<256xi8> to vector<256xi32>
    %11 = vector.reduction <add>, %10, %c0_i32 : vector<256xi32> into i32
    %12 = vector.broadcast %11 : i32 to vector<1xi32>
    %13 = vector.transfer_read %1[%workgroup_id_x], %c0_i32 : memref<2xi32>, vector<i32>
    %14 = vector.shape_cast %13 : vector<i32> to vector<1xi32>
    %15 = vector.transfer_read %4[%workgroup_id_x], %c0_i32 : memref<2xi32, #hal.descriptor_type<storage_buffer>>, vector<i32>
    %16 = vector.shape_cast %15 : vector<i32> to vector<1xi32>
    %17 = vector.transfer_read %2[%workgroup_id_x], %c0_i32 : memref<2xi32>, vector<i32>
    %18 = vector.shape_cast %17 : vector<i32> to vector<1xi32>
    %19 = arith.muli %12, %cst_5 : vector<1xi32>
    %20 = arith.subi %16, %19 : vector<1xi32>
    %21 = arith.addi %14, %20 : vector<1xi32>
    %22 = "tosa.apply_scale"(%21, %18, %cst_4) <{double_round = true}> : (vector<1xi32>, vector<1xi32>, vector<1xi8>) -> vector<1xi32>
    %23 = arith.addi %22, %cst_3 : vector<1xi32>
    %24 = arith.cmpi slt, %23, %cst_5 : vector<1xi32>
    %25 = arith.select %24, %cst_5, %23 : vector<1xi1>, vector<1xi32>
    %26 = arith.cmpi sgt, %23, %cst_2 : vector<1xi32>
    %27 = arith.select %26, %cst_2, %25 : vector<1xi1>, vector<1xi32>
    %28 = arith.trunci %27 : vector<1xi32> to vector<1xi8>
    %29 = arith.sitofp %28 : vector<1xi8> to vector<1xf32>
    %30 = arith.subf %29, %cst_1 : vector<1xf32>
    %31 = arith.mulf %30, %cst_0 : vector<1xf32>
    %32 = vector.shape_cast %31 : vector<1xf32> to vector<f32>
    vector.transfer_write %32, %alloc[] : vector<f32>, memref<f32, #gpu.address_space<workgroup>>
  }
  gpu.barrier
  %7 = vector.transfer_read %alloc[], %cst : memref<f32, #gpu.address_space<workgroup>>, vector<f32>
  %8 = arith.cmpi eq, %0, %c0 : index
  scf.if %8 {
    vector.transfer_write %7, %5[%workgroup_id_x] : vector<f32>, memref<2xf32, strided<[1], offset: 16>, #hal.descriptor_type<storage_buffer>>
  }
  return
}

I'll send another patch to handle that.

pzread commented 1 year ago

The patch to fix the vector distribution has been sent out for review: https://reviews.llvm.org/D154870

pzread commented 5 months ago

Unassigned myself as I'm not working on this right now.

antiagainst commented 2 months ago

Closing this for now; no actions planned.