pzread closed this issue 2 months ago
As discussed on Discord, we just need to add a vector-to-SPIR-V pattern to handle such size-1 shape casts -- they are no-ops when translating to SPIR-V, given that both the source and result types convert to scalar values. Assigning to @kuhar to fix it.
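For illustration, here is a minimal sketch of the kind of cast in question (types chosen for illustration, not taken from the reproducer):

```mlir
// Both vector<f32> and vector<1xf32> lower to a plain f32 scalar in
// SPIR-V, so this cast carries no data and the lowering pattern can
// simply forward the already-converted source operand.
%0 = vector.shape_cast %v : vector<f32> to vector<1xf32>
```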
I'm actually more than happy to try to fix this :) Assigning it to myself as well.
Cool, that's great! Feel free to ping @kuhar or me if you have questions then! :)
It looks like there is a size regression when combining dropping unit dims on vector transfers with converting trivial shape_cast ops to no-ops:
https://github.com/openxla/iree/pull/14220#issuecomment-1606270765
I'll investigate further.
The problem is in VectorReductionToGPU: the patterns in mlir::vector::populatePropagateWarpVectorDistributionPatterns can't handle vector.shape_cast, which results in bad warp distribution:
----- After VectorReduceToGPU -----
func.func @main_dispatch_84_generic_2x256_i8xi32() {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %cst_0 = arith.constant dense<0.0125187514> : vector<1xf32>
  %cst_1 = arith.constant dense<-1.000000e+00> : vector<1xf32>
  %cst_2 = arith.constant dense<127> : vector<1xi32>
  %cst_3 = arith.constant dense<-1> : vector<1xi32>
  %cst_4 = arith.constant dense<39> : vector<1xi8>
  %cst_5 = arith.constant dense<-128> : vector<1xi32>
  %cst_6 = arith.constant dense<[16267, -17079]> : tensor<2xi32>
  %c20224 = arith.constant 20224 : index
  %c0_i8 = arith.constant 0 : i8
  %cst_7 = arith.constant dense<[1196100044, 1139971180]> : tensor<2xi32>
  %c0_i32 = arith.constant 0 : i32
  %c64 = arith.constant 64 : index
  %0 = gpu.thread_id x
  %1 = bufferization.to_memref %cst_6 : memref<2xi32>
  %2 = bufferization.to_memref %cst_7 : memref<2xi32>
  %3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c20224) flags(ReadOnly) : memref<256x2xi8, strided<[2, 1], offset: 20224>, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %3, 64 : memref<256x2xi8, strided<[2, 1], offset: 20224>, #hal.descriptor_type<storage_buffer>>
  %4 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<2xi32, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %4, 64 : memref<2xi32, #hal.descriptor_type<storage_buffer>>
  %5 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c64) : memref<2xf32, strided<[1], offset: 16>, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %5, 64 : memref<2xf32, strided<[1], offset: 16>, #hal.descriptor_type<storage_buffer>>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %6 = arith.cmpi eq, %0, %c0 : index
  %alloc = memref.alloc() : memref<f32, #gpu.address_space<workgroup>>
  scf.if %6 {
    %9 = vector.transfer_read %3[%c0, %workgroup_id_x], %c0_i8 {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (d0)>} : memref<256x2xi8, strided<[2, 1], offset: 20224>, #hal.descriptor_type<storage_buffer>>, vector<256xi8>
    %10 = arith.extsi %9 : vector<256xi8> to vector<256xi32>
    %11 = vector.reduction <add>, %10, %c0_i32 : vector<256xi32> into i32
    %12 = vector.broadcast %11 : i32 to vector<1xi32>
    %13 = vector.transfer_read %1[%workgroup_id_x], %c0_i32 : memref<2xi32>, vector<i32>
    %14 = vector.shape_cast %13 : vector<i32> to vector<1xi32>
    %15 = vector.transfer_read %4[%workgroup_id_x], %c0_i32 : memref<2xi32, #hal.descriptor_type<storage_buffer>>, vector<i32>
    %16 = vector.shape_cast %15 : vector<i32> to vector<1xi32>
    %17 = vector.transfer_read %2[%workgroup_id_x], %c0_i32 : memref<2xi32>, vector<i32>
    %18 = vector.shape_cast %17 : vector<i32> to vector<1xi32>
    %19 = arith.muli %12, %cst_5 : vector<1xi32>
    %20 = arith.subi %16, %19 : vector<1xi32>
    %21 = arith.addi %14, %20 : vector<1xi32>
    %22 = "tosa.apply_scale"(%21, %18, %cst_4) <{double_round = true}> : (vector<1xi32>, vector<1xi32>, vector<1xi8>) -> vector<1xi32>
    %23 = arith.addi %22, %cst_3 : vector<1xi32>
    %24 = arith.cmpi slt, %23, %cst_5 : vector<1xi32>
    %25 = arith.select %24, %cst_5, %23 : vector<1xi1>, vector<1xi32>
    %26 = arith.cmpi sgt, %23, %cst_2 : vector<1xi32>
    %27 = arith.select %26, %cst_2, %25 : vector<1xi1>, vector<1xi32>
    %28 = arith.trunci %27 : vector<1xi32> to vector<1xi8>
    %29 = arith.sitofp %28 : vector<1xi8> to vector<1xf32>
    %30 = arith.subf %29, %cst_1 : vector<1xf32>
    %31 = arith.mulf %30, %cst_0 : vector<1xf32>
    %32 = vector.shape_cast %31 : vector<1xf32> to vector<f32>
    vector.transfer_write %32, %alloc[] : vector<f32>, memref<f32, #gpu.address_space<workgroup>>
  }
  gpu.barrier
  %7 = vector.transfer_read %alloc[], %cst : memref<f32, #gpu.address_space<workgroup>>, vector<f32>
  %8 = arith.cmpi eq, %0, %c0 : index
  scf.if %8 {
    vector.transfer_write %7, %5[%workgroup_id_x] : vector<f32>, memref<2xf32, strided<[1], offset: 16>, #hal.descriptor_type<storage_buffer>>
  }
  return
}
I'll send another patch to handle that.
A patch to fix the vector distribution has been sent out for review: https://reviews.llvm.org/D154870
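Roughly, the missing propagation pattern needs to hoist a trivial vector.shape_cast out of the warp region so its producers can still be distributed. An illustrative sketch follows (shapes and warp size simplified; this is not the exact rewrite from the patch, and %v stands for some value computed inside the region):

```mlir
// Before: the shape_cast is trapped inside the warp region, and the
// distribution patterns don't know how to move it out (sketch).
%r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<f32>) {
  %c = vector.shape_cast %v : vector<1xf32> to vector<f32>
  vector.yield %c : vector<f32>
}

// After: yield the cast's source instead and re-create the cast
// outside the region; for size-1 vectors the distributed value is
// unchanged, so the rewrite is shape-preserving (sketch).
%s = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
  vector.yield %v : vector<1xf32>
}
%r2 = vector.shape_cast %s : vector<1xf32> to vector<f32>
```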
Unassigned myself as I'm not working on this now
Closing this for now -- no actions planned
OptimizeVectorTransferPass, used in the SPIR-V pipeline, might generate vector.shape_cast during optimizations. For example: https://github.com/openxla/iree/blob/6c016cac1c94ddf72314ae85053142f4b9babf87/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp#L102-L104
But ConvertToSPIRVPass calls populateVectorToSPIRVPatterns during lowering, which doesn't handle vector.shape_cast.
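Concretely, the unhandled op is a rank-changing cast between a 0-D and a 1-D vector, as in the IR dumps in this issue:

```mlir
// populateVectorToSPIRVPatterns has no pattern for this op, so the
// conversion fails even though both types become an i32 scalar in SPIR-V.
%1 = vector.shape_cast %0 : vector<i32> to vector<1xi32>
```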
We ran into this issue when trying to drop unit dims for vector transfers in OptimizeVectorTransferPass (#13340), which creates vector.shape_cast to drop the unit dims on vectors. Here is an example after OptimizeVectorTransferPass with dropping unit dims from #13340, where vector.shape_cast is generated during the optimization.

Reproduce

To reproduce, I dumped the IR containing vector.shape_cast just before ConvertToSPIRV: https://gist.github.com/pzread/48352affa3aa5255d285f81091e1ece8