Max191 opened this issue 1 year ago
Here is the IR just before ConvertToLLVM:
// -----// IR Dump After CSE (cse) //----- //
module {
func.func @quantized_matmul_f32_dispatch_0_generic_1x1x11008x32x128_f32() {
%cst = arith.constant dense<0.000000e+00> : vector<1xf32>
%cst_0 = arith.constant dense<0.000000e+00> : vector<1x1x1xf32>
%c128 = arith.constant 128 : index
%c64 = arith.constant 64 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<22544384xi8>
memref.assume_alignment %0, 64 : memref<22544384xi8>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<11008x32x1xf32>
memref.assume_alignment %1, 64 : memref<11008x32x1xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<11008x32x1xf32>
memref.assume_alignment %2, 64 : memref<11008x32x1xf32>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<1x1x32x128xf32>
memref.assume_alignment %3, 64 : memref<1x1x32x128xf32>
%4 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) alignment(64) offset(%c0) : memref<1x1x11008xf32>
memref.assume_alignment %4, 64 : memref<1x1x11008xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
cf.br ^bb1(%c0 : index)
^bb1(%5: index): // 2 preds: ^bb0, ^bb6
%6 = arith.cmpi slt, %5, %c32 : index
cf.cond_br %6, ^bb2(%c0, %cst_0 : index, vector<1x1x1xf32>), ^bb7
^bb2(%7: index, %8: vector<1x1x1xf32>): // 2 preds: ^bb1, ^bb5
%9 = arith.cmpi slt, %7, %c32 : index
cf.cond_br %9, ^bb3(%c0, %8 : index, vector<1x1x1xf32>), ^bb6
^bb3(%10: index, %11: vector<1x1x1xf32>): // 2 preds: ^bb2, ^bb4
%12 = arith.cmpi slt, %10, %c128 : index
cf.cond_br %12, ^bb4, ^bb5
^bb4: // pred: ^bb3
%13 = vector.load %3[%c0, %c0, %7, %10] : memref<1x1x32x128xf32>, vector<64xf32>
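// The next three ops compute the byte offset into the packed i4 buffer (two i4 values per byte, hence the floordiv 2 on the inner index), load 32 bytes, and reinterpret them as 64 i4 values.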
%14 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 * 64 + s2 * 2048 + s3 * 65536 + s1 floordiv 2)>()[%7, %10, %5, %workgroup_id_x]
%15 = vector.load %0[%14] : memref<22544384xi8>, vector<32xi8>
%16 = vector.bitcast %15 : vector<32xi8> to vector<64xi4>
%17 = vector.broadcast %16 : vector<64xi4> to vector<1x1x64xi4>
%18 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 32)>()[%5, %workgroup_id_x]
%19 = memref.load %1[%18, %7, %c0] : memref<11008x32x1xf32>
%20 = vector.broadcast %19 : f32 to vector<1x1x1x1x64xf32>
%21 = memref.load %2[%18, %7, %c0] : memref<11008x32x1xf32>
%22 = vector.broadcast %21 : f32 to vector<1x1x1x1x64xf32>
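// Dequantize the unpacked i4 weights: widen to f32, subtract the value loaded from binding 2, and multiply by the value loaded from binding 1.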
%23 = arith.extui %17 : vector<1x1x64xi4> to vector<1x1x64xi32>
%24 = arith.uitofp %23 : vector<1x1x64xi32> to vector<1x1x64xf32>
%25 = vector.broadcast %24 : vector<1x1x64xf32> to vector<1x1x1x1x64xf32>
%26 = arith.subf %25, %22 : vector<1x1x1x1x64xf32>
%27 = arith.mulf %26, %20 : vector<1x1x1x1x64xf32>
%28 = vector.extract %11[0, 0, 0] : vector<1x1x1xf32>
%29 = vector.extract %27[0, 0, 0, 0] : vector<1x1x1x1x64xf32>
%30 = arith.mulf %13, %29 : vector<64xf32>
%31 = vector.reduction <add>, %30, %28 : vector<64xf32> into f32
%32 = vector.insert %31, %cst [0] : f32 into vector<1xf32>
%33 = vector.broadcast %32 : vector<1xf32> to vector<1x1x1xf32>
%34 = arith.addi %10, %c64 : index
cf.br ^bb3(%34, %33 : index, vector<1x1x1xf32>)
^bb5: // pred: ^bb3
%35 = arith.addi %7, %c1 : index
cf.br ^bb2(%35, %11 : index, vector<1x1x1xf32>)
^bb6: // pred: ^bb2
%36 = vector.extract %8[0, 0, 0] : vector<1x1x1xf32>
%37 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 32)>()[%5, %workgroup_id_x]
memref.store %36, %4[%c0, %c0, %37] : memref<1x1x11008xf32>
%38 = arith.addi %5, %c1 : index
cf.br ^bb1(%38 : index)
^bb7: // pred: ^bb1
return
}
}
It loads a half-sized vector of i8, and then uses vector.bitcast to convert to i4 without moving data. However, each i4 does eventually have to be moved into its own single-byte container, so perhaps this is the place where we should try to fix this issue.
One idea I have would be to use two vector.maskedload ops here instead of
%15 = vector.load %0[%14] : memref<22544384xi8>, vector<32xi8>
and then, rather than doing a vector.bitcast, we can just keep the vector in i8, and the arith.extui can simply change from i4 -> i32 to i8 -> i32.
Does this seem like a reasonable approach?
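As a rough, hand-written illustration of the op shapes this idea refers to (the function, operand names, and the all-true mask are made up for this sketch; it shows a single maskedload and the cheaper i8 -> i32 extend, not a complete replacement for the nibble unpack):
func.func @maskedload_sketch(%buf: memref<22544384xi8>, %off: index) -> vector<32xi32> {
  %mask = arith.constant dense<true> : vector<32xi1>
  %pad = arith.constant dense<0> : vector<32xi8>
  // Masked load of 32 bytes from the packed weight buffer.
  %w = vector.maskedload %buf[%off], %mask, %pad
      : memref<22544384xi8>, vector<32xi1>, vector<32xi8> into vector<32xi8>
  // With the data kept in i8, the widening is a plain byte-to-word extend.
  %ext = arith.extui %w : vector<32xi8> to vector<32xi32>
  return %ext : vector<32xi32>
}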
I made a draft PR for a pass to rewrite the above bitcast sequence: https://github.com/openxla/iree/pull/14921
It should still be generalized further, but I'll post it here for discussion. It converts the sequence
vector.load -> vector.bitcast -> vector.broadcast -> arith.extui
into
vector.load -> vector.bitcast -> vector.shuffle -> arith.and -> arith.shrui -> vector.broadcast
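In effect, the rewritten sequence unpacks the two i4 nibbles in each byte into their own i8 lanes using a shuffle plus a per-lane shift and mask, so the final widening works on whole bytes. A hand-written sketch of that shape, shown on 8 packed bytes (16 i4 values) for brevity (the names, widths, and exact op order here are illustrative, not the IR the pass emits):
func.func @unpack_i4_sketch(%src: memref<?xi8>, %off: index) -> vector<16xi8> {
  // Each byte packs two i4 values; assume little-endian packing (low nibble = even element).
  %bytes = vector.load %src[%off] : memref<?xi8>, vector<8xi8>
  // Duplicate every byte so each nibble gets its own i8 lane.
  %dup = vector.shuffle %bytes, %bytes
      [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7] : vector<8xi8>, vector<8xi8>
  // Even lanes keep the low nibble (shift by 0); odd lanes take the high nibble (shift by 4).
  %shift = arith.constant dense<[0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4]> : vector<16xi8>
  %shr = arith.shrui %dup, %shift : vector<16xi8>
  %mask = arith.constant dense<15> : vector<16xi8>
  // Each lane now holds one i4 value in an i8 container, so a later arith.extui
  // becomes a plain i8 -> i32 widening instead of a scalarized nibble extraction.
  %vals = arith.andi %shr, %mask : vector<16xi8>
  return %vals : vector<16xi8>
}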
It brings the llama2 CPU latency down from 420 ms to 175 ms on my machine (these numbers include the custom tuning config I mentioned above).
What happened?
The following IR runs faster for i8 than for i4. The IR dumps and disassemblies are uploaded here: https://drive.google.com/drive/folders/1UTd73C41xpoAHmDQKPSYfVfh_Lewz_4K?usp=sharing
Looking at the disassembly for i4, when the i4 weights are loaded, they need to be extracted and shifted into one i4 per byte, which is all scalarized.
Steps to reproduce your issue
What component(s) does this issue relate to?
No response
Version information
On this branch: https://github.com/Max191/iree/tree/llama2_int4_cpu
This is based on top of fe8f79e930f708af2121b1241474d6229cb5593e with a custom tile-size config from tuning. The same issue occurs on fe8f79e930f708af2121b1241474d6229cb5593e without the tuning config.
Additional context
No response