kuhar opened this issue 3 months ago.
cc: @inbelic
Dynamic shift amounts can be slow on many ISAs because they are often implemented as microcoded loops. I would be interested to know whether we hit this on other targets too, since this pattern comes from higher up in the stack! /cc @hanhanW @antiagainst
From initial inspection, these operations result from the rewrite pattern `BitCastRewriter::genericRewrite` (implemented here).
In this case, we are transforming a `vector.bitcast` into a sequence of `vector.shuffle`, `arith.andi`, and `arith.shrui` operations. So we could potentially rewrite this pattern to use splat vectors for the shifts, as mentioned. Alternatively, would it be possible to prevent this pattern from expanding the `vector.bitcast`, to allow a direct lowering to `spirv.Bitcast`? WDYT?
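To make the expansion concrete, here is a lane-by-lane Python emulation of a shuffle, andi, shrui sequence that unpacks two bytes into four i4 lanes. This is a sketch under stated assumptions, not the `BitCastRewriter` code: it assumes little-endian nibble order (lane 2k holds the low nibble of byte k), and the helper names (`shuffle`, `andi`, `shrui`, `expand_bitcast_nonsplat`) are hypothetical stand-ins for the MLIR ops. The final shift uses the per-lane amounts [0, 4, 0, 4], which is the non-splat form discussed in this issue.

```python
# Scalar emulation of vector ops; each Python list models one vector value.
MASK_I8 = 0xFF


def shuffle(a, b, indices):
    """vector.shuffle-like semantics: indices address the concatenation a ++ b."""
    both = a + b
    return [both[i] for i in indices]


def andi(a, b):
    """Lane-wise bitwise AND, truncated to 8 bits."""
    return [(x & y) & MASK_I8 for x, y in zip(a, b)]


def shrui(a, b):
    """Lane-wise logical right shift; the shift amount may differ per lane."""
    return [(x & MASK_I8) >> y for x, y in zip(a, b)]


def expand_bitcast_nonsplat(packed):
    """Hypothetical sketch: unpack 2 bytes into 4 nibble lanes held in i8 lanes.

    Assumes little-endian nibble order: lane 2*k is the low nibble of byte k.
    """
    # Replicate each byte into the two lanes that will hold its nibbles.
    replicated = shuffle(packed, packed, [0, 0, 1, 1])
    # Per-lane masks keep either the low or the high nibble.
    masked = andi(replicated, [0x0F, 0xF0, 0x0F, 0xF0])
    # Per-lane (non-splat) shift amounts: the pattern that is slow on RDNA.
    return shrui(masked, [0, 4, 0, 4])


print(expand_bitcast_nonsplat([0x21, 0x43]))  # [1, 2, 3, 4]
```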
IIRC @nicolasvasilache wrote this based on the CPU performance work at the time. It lowers bitcasts in a principled way, especially for more complicated cases such as i5, but the expansion is suboptimal on GPUs. I don't think we can rely on `spirv.Bitcast`, because IIUC the vector element type is i4, which needs to be expanded to an integer type supported by the target.
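For reference, the value that any expansion has to reproduce can be written down directly. The sketch below assumes little-endian nibble order and zero-extension of each i4 lane into a wider lane (modeled as a plain Python int), standing in for the target-supported integer type; the function name `bitcast_i8_to_i4` is hypothetical.

```python
def bitcast_i8_to_i4(packed):
    """Hypothetical reference: reinterpret each byte as two 4-bit lanes.

    Assumes little-endian nibble order (lane 2*k is the low nibble of byte k).
    Each 4-bit value is zero-extended into a Python int, standing in for a
    lane of a target-supported type such as i8 or i32.
    """
    lanes = []
    for byte in packed:
        lanes.append(byte & 0x0F)         # low nibble
        lanes.append((byte >> 4) & 0x0F)  # high nibble
    return lanes


print(bitcast_i8_to_i4([0x21, 0x43]))  # [1, 2, 3, 4]
```

A `vector<4xi4>` occupies the same 16 bits as a `vector<2xi8>`, but without native 4-bit lanes each value still has to land in its own wider lane, which is the unpacking work the expansion performs.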
I think the solution is to change the order of operations to avoid non-splat shifts, turning this into something like:
```mlir
%cst_vec_2xi8 = spirv.Constant dense<[4, 4]> : vector<2xi8>
%46 = spirv.BitwiseAnd %27, %cst_vec_4xi8_0 : vector<4xi8>
%44 = spirv.VectorShuffle [0 : i32, 1 : i32] %46, %46 : vector<4xi8>, vector<4xi8> -> vector<2xi8>
%shifted = spirv.ShiftRightLogical %44, %cst_vec_2xi8 : vector<2xi8>, vector<2xi8>
// Indices 0-1 select lanes of %44; indices 2-3 select lanes of %shifted.
%45 = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %44, %shifted : vector<2xi8>, vector<2xi8> -> vector<4xi8>
```
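One way to see why reordering can avoid non-splat shifts: mask and shift the whole packed vector by splat constants first, then interleave the low and high nibbles with a single shuffle. The Python sketch below emulates that idea lane by lane. It is an illustrative variant under a little-endian nibble-order assumption, not a transcription of the SPIR-V sequence, and the names `interleave_indices` and `expand_bitcast_splat` are hypothetical. For two n-lane operands, the interleaving indices are [0, n, 1, n+1, and so on], i.e. [0, 2, 1, 3] for n = 2.

```python
def interleave_indices(n):
    """Shuffle indices alternating lanes of two n-lane operands: [0, n, 1, n+1, ...]."""
    return [idx for i in range(n) for idx in (i, n + i)]


def expand_bitcast_splat(packed):
    """Hypothetical sketch: unpack bytes into nibble lanes using only splat constants.

    Assumes little-endian nibble order: lane 2*k is the low nibble of byte k.
    """
    n = len(packed)
    # Every lane is ANDed with the same constant 0x0F (a splat mask).
    low = [b & 0x0F for b in packed]
    # Every lane is shifted right by the same constant 4 (a splat shift amount).
    high = [(b & 0xFF) >> 4 for b in packed]
    # One shuffle over the concatenation low ++ high interleaves the lanes.
    both = low + high
    return [both[i] for i in interleave_indices(n)]


print(interleave_indices(2))               # [0, 2, 1, 3]
print(expand_bitcast_splat([0x21, 0x43]))  # [1, 2, 3, 4]
```

Here the masks and shift amounts are identical in every lane, so all per-lane variation lives in the shuffle, which does not depend on variable shift hardware.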
On RDNA GPUs, vector shift instructions are much faster when the shift amount is a splat constant. However, we seem to emit non-splat shifts for the int4 matvec from Llama 2:
Input
Compile command:
The dumps directory contains IR at the level of the SPIR-V dialect (`dumps/module__main_dispatch_0__main_dispatch_0_generic_4096x32x128_f16.spirv.mlir`), with the following shift ops:

We should figure out which pass produces this instruction sequence (`--mlir-print-ir-after-all`) and change it to avoid shifting by non-splat amounts.
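To help narrow down where the non-splat shifts appear, a small heuristic scanner could flag SPIR-V shift ops whose amount operand is not a splat `spirv.Constant`. This is a hypothetical helper, not part of IREE or MLIR: it uses regular expressions instead of a real MLIR parser, only recognizes `dense<...>` constants defined in the same text, and treats any other shift amount as dynamic. It is meant for SPIR-V-level text such as the dump above.

```python
import re

# Hypothetical patterns for SPIR-V-dialect text; not an MLIR parser.
CONST_RE = re.compile(r"(%[\w.$-]+)\s*=\s*spirv\.Constant\s+dense<(.+?)>\s*:")
SHIFT_RE = re.compile(
    r"(%[\w.$-]+)\s*=\s*spirv\."
    r"(ShiftRightLogical|ShiftRightArithmetic|ShiftLeftLogical)"
    r"\s+(%[\w.$-]+),\s*(%[\w.$-]+)"
)


def is_splat(dense_body):
    """`dense<4>` is a splat; `dense<[4, 4]>` is one only if all entries match."""
    body = dense_body.strip()
    if not body.startswith("["):
        return True
    values = [v.strip() for v in body.strip("[]").split(",")]
    return len(set(values)) == 1


def find_nonsplat_shifts(mlir_text):
    """Return (result, op) pairs for shifts whose amount is not a known splat."""
    constants = {name: is_splat(body) for name, body in CONST_RE.findall(mlir_text)}
    flagged = []
    for result, op, _value, amount in SHIFT_RE.findall(mlir_text):
        # Amounts not defined by a visible spirv.Constant are treated as dynamic.
        if not constants.get(amount, False):
            flagged.append((result, op))
    return flagged


sample = """
%c4 = spirv.Constant dense<4> : vector<2xi8>
%mix = spirv.Constant dense<[0, 4, 0, 4]> : vector<4xi8>
%a = spirv.ShiftRightLogical %x, %c4 : vector<2xi8>, vector<2xi8>
%b = spirv.ShiftRightLogical %y, %mix : vector<4xi8>, vector<4xi8>
"""
print(find_nonsplat_shifts(sample))  # [('%b', 'ShiftRightLogical')]
```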