
[SPIR-V] Avoid shifting by non-splat amounts #16913

Open kuhar opened 3 months ago

kuhar commented 3 months ago

On RDNA GPUs, vector shift instructions are much faster when the shift amount is a splat constant. However, we seem to emit non-splat shifts for the int4 matvec from LLama2:
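For reference, the difference at the arith level (a minimal sketch; %v and the constant values are illustrative, not taken from the dump below):

// Splat shift: every lane shifts by the same amount -- the fast form.
%c4 = arith.constant dense<4> : vector<4xi8>
%fast = arith.shrui %v, %c4 : vector<4xi8>
// Non-splat shift: per-lane amounts -- the slow form we currently emit.
%amts = arith.constant dense<[0, 4, 0, 4]> : vector<4xi8>
%slow = arith.shrui %v, %amts : vector<4xi8>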

Input

func.func @main() {
  %c64 = arith.constant 64 : index
  %c8192 = arith.constant 8192 : index
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f16
  %2 = util.unfoldable_constant dense<1> : tensor<4096x32x128xi4>
  %3 = util.unfoldable_constant dense<1.0> : tensor<4096x32xf16>
  %4 = util.unfoldable_constant dense<1.0> : tensor<4096x32xf16>
  %5 = util.unfoldable_constant dense<1.0> : tensor<32x128xf16>

  %9 = tensor.empty() : tensor<4096xf16>
  %10 = tensor.empty() : tensor<4096x32x128xf16>
  %11 = linalg.fill ins(%cst : f16) outs(%9 : tensor<4096xf16>) -> tensor<4096xf16>
  %12 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                     affine_map<(d0, d1, d2) -> (d0, d1)>,
                     affine_map<(d0, d1, d2) -> (d0, d1)>,
                     affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel"]
  } ins(%2, %3, %4 : tensor<4096x32x128xi4>, tensor<4096x32xf16>, tensor<4096x32xf16>)
    outs(%10 : tensor<4096x32x128xf16>) {
  ^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16):
    %14 = arith.extui %in : i4 to i32
    %15 = arith.uitofp %14 : i32 to f16
    %16 = arith.subf %15, %in_1 : f16
    %17 = arith.mulf %16, %in_0 : f16
    linalg.yield %17 : f16
  } -> tensor<4096x32x128xf16>
  %13 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>,
                     affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                     affine_map<(d0, d1, d2) -> (d0)>],
    iterator_types = ["parallel", "reduction", "reduction"]
  } ins(%5, %12 : tensor<32x128xf16>, tensor<4096x32x128xf16>)
    outs(%11 : tensor<4096xf16>) {
  ^bb0(%in: f16, %in_0: f16, %out: f16):
    %14 = arith.mulf %in, %in_0 : f16
    %15 = arith.addf %14, %out : f16
    linalg.yield %15 : f16
  } -> tensor<4096xf16>

  check.expect_eq_const(%13, dense<4096.0> : tensor<4096xf16>) : tensor<4096xf16>
  return
}

Compile command:

tools/iree-compile vmt_int4.mlir \
  --iree-hal-target-backends=vulkan-spirv \
  --iree-vulkan-target-triple=rdna3-7900-linux \
  --iree-hal-dump-executable-files-to=dumps \
  -o vmt_int4.vmfb

The dumps directory contains IR at the level of the SPIR-V dialect (dumps/module__main_dispatch_0__main_dispatch_0_generic_4096x32x128_f16.spirv.mlir), with the following shift ops:

      %cst_vec_4xi8 = spirv.Constant dense<[0, 4, 0, 4]> : vector<4xi8>
      // ...
      %44 = spirv.VectorShuffle [0 : i32, 1 : i32] %27, %27 : vector<4xi8>, vector<4xi8> -> vector<2xi8>
      %45 = spirv.VectorShuffle [0 : i32, 0 : i32, 1 : i32, 1 : i32] %44, %44 : vector<2xi8>, vector<2xi8> -> vector<4xi8>
      %46 = spirv.BitwiseAnd %45, %cst_vec_4xi8_0 : vector<4xi8>
      %47 = spirv.ShiftRightLogical %46, %cst_vec_4xi8 : vector<4xi8>, vector<4xi8>

We should figure out which pass produces this instruction sequence (e.g., by compiling with --mlir-print-ir-after-all) and change it to avoid shifting by non-splat amounts.

kuhar commented 3 months ago

cc: @inbelic

benvanik commented 3 months ago

Dynamic shift amounts can be slow on many ISAs, as they are often implemented as microcoded loops - I'd be interested to know if we hit this on other targets, since this comes from higher up in the stack! /cc @hanhanW @antiagainst

inbelic commented 3 months ago

From initial inspection, these operations are the result of the rewrite pattern BitCastRewriter::genericRewrite, implemented here.

In this case we are transforming a vector.bitcast into a sequence of vector.shuffle, arith.andi, and arith.shrui operations. So we could potentially rewrite this pass to use splat vectors for the shifts, as mentioned. Alternatively, would it be possible to prevent this pass from expanding the vector.bitcast, allowing a direct lowering to spirv.Bitcast? WDYT?
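To make the pattern concrete, here is a minimal sketch of the expansion at the vector/arith level, assuming a vector<2xi8> to vector<4xi4> bitcast (shapes and value names are illustrative, not the pass's literal output):

// The narrow-type bitcast being emulated:
//   %nibbles = vector.bitcast %bytes : vector<2xi8> to vector<4xi4>
// Conceptual expansion: duplicate each byte, mask out one nibble per lane,
// then shift right by a per-lane (non-splat) amount.
%dup  = vector.shuffle %bytes, %bytes [0, 0, 1, 1] : vector<2xi8>, vector<2xi8>
%mask = arith.constant dense<[15, -16, 15, -16]> : vector<4xi8>
%amts = arith.constant dense<[0, 4, 0, 4]> : vector<4xi8>
%and  = arith.andi %dup, %mask : vector<4xi8>
%shr  = arith.shrui %and, %amts : vector<4xi8>

This mirrors the spirv.VectorShuffle / spirv.BitwiseAnd / spirv.ShiftRightLogical sequence in the dump above.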

kuhar commented 3 months ago

> From initial inspection, these operations are the result of the rewrite pattern BitCastRewriter::genericRewrite, implemented here.

IIRC @nicolasvasilache wrote this based on the CPU performance work at the time. It lowers the bitcasts in a principled way, especially for more complicated cases like i5, but the expansion is suboptimal for GPUs. I don't think we can rely on spirv.Bitcast because the vector element type is i4 IIUC, which needs to be expanded to an integer type supported by the target.

I think the solution would be to change the order of operations to avoid non-splat shifts, i.e., to turn this into something like:

      %cst_vec_2xi8 = spirv.Constant dense<[4, 4]> : vector<2xi8>
      %46 = spirv.BitwiseAnd %27, %cst_vec_4xi8_0 : vector<4xi8>
      %44 = spirv.VectorShuffle [0 : i32, 1 : i32] %46, %46 : vector<4xi8>, vector<4xi8> -> vector<2xi8>
      %44_s = spirv.ShiftRightLogical %44, %cst_vec_2xi8 : vector<2xi8>, vector<2xi8>
      %45 = spirv.VectorShuffle [0 : i32, 4 : i32, 1 : i32, 5 : i32] %44, %44_s : vector<2xi8>, vector<2xi8> -> vector<4xi8>
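
At the vector level, where the rewrite pattern operates, the same reordering could look like the sketch below (assuming the 2-byte/4-nibble shape from above; value names are hypothetical):

// Mask the low and high nibbles separately, shift only the high-nibble
// half by a splat amount, then interleave the halves back together.
%lo_mask = arith.constant dense<15> : vector<2xi8>
%hi_mask = arith.constant dense<-16> : vector<2xi8>
%c4 = arith.constant dense<4> : vector<2xi8>
%lo = arith.andi %bytes, %lo_mask : vector<2xi8>
%hi = arith.andi %bytes, %hi_mask : vector<2xi8>
%hi_s = arith.shrui %hi, %c4 : vector<2xi8>
// [lo0, hi0, lo1, hi1]: same lane order as before, with only splat shifts.
%out = vector.shuffle %lo, %hi_s [0, 2, 1, 3] : vector<2xi8>, vector<2xi8>

Both shift amounts are now splat, matching the fast path described at the top of the issue.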