iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

[CPU] Quantized matmul `i4` slower than `i8` #14914

Open Max191 opened 1 year ago

Max191 commented 1 year ago

What happened?

The following IR, shown here with i4 weights, runs faster when the weights are i8 than when they are i4:

builtin.module {
  func.func @quantized_matmul_f32(%arg0: tensor<11008x32x128xi4>, %arg1: tensor<11008x32x1xf32>, %arg2: tensor<11008x32x1xf32>, %arg3: tensor<1x1x32x128xf32>) -> tensor<1x1x11008xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %4 = tensor.empty() : tensor<1x1x11008xf32>
    %5 = tensor.empty() : tensor<11008x32x128xf32>
    %6 = linalg.fill ins(%cst : f32) outs(%4 : tensor<1x1x11008xf32>) -> tensor<1x1x11008xf32>
    %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg1, %arg2 : tensor<11008x32x128xi4>, tensor<11008x32x1xf32>, tensor<11008x32x1xf32>) outs(%5 : tensor<11008x32x128xf32>) {
    ^bb0(%in: i4, %in_0: f32, %in_1: f32, %out: f32):
      %9 = arith.extui %in : i4 to i32
      %10 = arith.uitofp %9 : i32 to f32
      %11 = arith.subf %10, %in_1 : f32
      %12 = arith.mulf %11, %in_0 : f32
      linalg.yield %12 : f32
    } -> tensor<11008x32x128xf32>
    %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg3, %7 : tensor<1x1x32x128xf32>, tensor<11008x32x128xf32>) outs(%6 : tensor<1x1x11008xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %9 = arith.mulf %in, %in_0 : f32
      %10 = arith.addf %9, %out : f32
      linalg.yield %10 : f32
    } -> tensor<1x1x11008xf32>
    return %8 : tensor<1x1x11008xf32>
  }
}
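For reference, here is a plain-Python sketch of what the two linalg.generic ops compute. The first generic dequantizes each weight as (float(q) - %arg2) * %arg1, so %arg1 acts as a per-group scale and %arg2 as a per-group zero point. The second generic reduces over the group and inner dimensions. The function names are hypothetical, and the trailing size-1 and leading 1x1 dimensions are dropped for readability.

```python
def dequantize(q, scale, zero):
    # q[n][g][k]: unsigned i4 weights held as Python ints in [0, 15]
    # scale[n][g] ~ %arg1, zero[n][g] ~ %arg2 (trailing size-1 dim dropped)
    # Mirrors: extui -> uitofp -> subf(zero) -> mulf(scale)
    return [
        [[(float(v) - zero[n][g]) * scale[n][g] for v in q[n][g]]
         for g in range(len(q[n]))]
        for n in range(len(q))
    ]


def quantized_matvec(x, q, scale, zero):
    # x[g][k] ~ %arg3 with its leading 1x1 dims dropped
    # Mirrors the second generic: out[n] = sum over (g, k) of x[g][k] * w[n][g][k]
    w = dequantize(q, scale, zero)
    out = []
    for n in range(len(w)):
        acc = 0.0  # matches the linalg.fill with 0.0
        for g in range(len(w[n])):
            for k in range(len(w[n][g])):
                acc += x[g][k] * w[n][g][k]
        out.append(acc)
    return out
```

With N=2 rows, one group, and K=2, `quantized_matvec([[1.0, 1.0]], [[[1, 3]], [[15, 0]]], [[2.0], [0.5]], [[1.0], [0.0]])` gives `[4.0, 7.5]`.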

The IR dumps and disassemblies are uploaded here: https://drive.google.com/drive/folders/1UTd73C41xpoAHmDQKPSYfVfh_Lewz_4K?usp=sharing

In the i4 disassembly, the loaded i4 weights have to be extracted and shifted so that each i4 value occupies its own byte, and this unpacking is entirely scalarized.
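As a rough illustration of the per-byte work involved, the sketch below unpacks two i4 values from every byte into one value per output element. It assumes little-endian nibble order, i.e. the even-indexed i4 element sits in the low nibble. This is an assumption about the packing, not something read off the uploaded disassembly, and the function name is hypothetical.

```python
def unpack_i4_scalar(packed: bytes) -> list:
    # One byte in, two values out: the scalar work done for every loaded byte.
    out = []
    for byte in packed:
        out.append(byte & 0x0F)         # low nibble -> even element (assumed order)
        out.append((byte >> 4) & 0x0F)  # high nibble -> odd element
    return out
```

For example, `unpack_i4_scalar(bytes([0x21, 0xF0]))` returns `[1, 2, 0, 15]`. Done one byte at a time, this costs a mask, a shift, and two stores per input byte.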

Steps to reproduce your issue

  1. Compile command:
    iree-compile --iree-input-type=none --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu-features=host --iree-llvmcpu-target-triple=x86_64-linux-gnu --iree-llvmcpu-enable-microkernels --iree-llvmcpu-stack-allocation-limit=256000 --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-vm-bytecode-module-strip-source-map=true --iree-util-zero-fill-elided-attrs --iree-vm-target-truncate-unsupported-floats --iree-codegen-check-ir-before-llvm-conversion=false --iree-opt-const-expr-hoisting=False --iree-llvmcpu-keep-linker-artifacts=false --iree-llvmcpu-link-embedded=false quantized_matmul_i4.mlir

What component(s) does this issue relate to?

No response

Version information

On this branch: https://github.com/Max191/iree/tree/llama2_int4_cpu. It is based on fe8f79e930f708af2121b1241474d6229cb5593e with a custom tile-size config from tuning. The same issue occurs on fe8f79e930f708af2121b1241474d6229cb5593e without the tuning config.

Additional context

No response

Max191 commented 1 year ago

Here is the IR just before ConvertToLLVM:


// -----// IR Dump After CSE (cse) //----- //
module {
  func.func @quantized_matmul_f32_dispatch_0_generic_1x1x11008x32x128_f32() {
    %cst = arith.constant dense<0.000000e+00> : vector<1xf32>
    %cst_0 = arith.constant dense<0.000000e+00> : vector<1x1x1xf32>
    %c128 = arith.constant 128 : index
    %c64 = arith.constant 64 : index
    %c32 = arith.constant 32 : index
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<22544384xi8>
    memref.assume_alignment %0, 64 : memref<22544384xi8>
    %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<11008x32x1xf32>
    memref.assume_alignment %1, 64 : memref<11008x32x1xf32>
    %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<11008x32x1xf32>
    memref.assume_alignment %2, 64 : memref<11008x32x1xf32>
    %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<1x1x32x128xf32>
    memref.assume_alignment %3, 64 : memref<1x1x32x128xf32>
    %4 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) alignment(64) offset(%c0) : memref<1x1x11008xf32>
    memref.assume_alignment %4, 64 : memref<1x1x11008xf32>
    %workgroup_id_x = hal.interface.workgroup.id[0] : index
    cf.br ^bb1(%c0 : index)
  ^bb1(%5: index):  // 2 preds: ^bb0, ^bb6
    %6 = arith.cmpi slt, %5, %c32 : index
    cf.cond_br %6, ^bb2(%c0, %cst_0 : index, vector<1x1x1xf32>), ^bb7
  ^bb2(%7: index, %8: vector<1x1x1xf32>):  // 2 preds: ^bb1, ^bb5
    %9 = arith.cmpi slt, %7, %c32 : index
    cf.cond_br %9, ^bb3(%c0, %8 : index, vector<1x1x1xf32>), ^bb6
  ^bb3(%10: index, %11: vector<1x1x1xf32>):  // 2 preds: ^bb2, ^bb4
    %12 = arith.cmpi slt, %10, %c128 : index
    cf.cond_br %12, ^bb4, ^bb5
  ^bb4:  // pred: ^bb3
    %13 = vector.load %3[%c0, %c0, %7, %10] : memref<1x1x32x128xf32>, vector<64xf32>
    %14 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 * 64 + s2 * 2048 + s3 * 65536 + s1 floordiv 2)>()[%7, %10, %5, %workgroup_id_x]
    %15 = vector.load %0[%14] : memref<22544384xi8>, vector<32xi8>
    %16 = vector.bitcast %15 : vector<32xi8> to vector<64xi4>
    %17 = vector.broadcast %16 : vector<64xi4> to vector<1x1x64xi4>
    %18 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 32)>()[%5, %workgroup_id_x]
    %19 = memref.load %1[%18, %7, %c0] : memref<11008x32x1xf32>
    %20 = vector.broadcast %19 : f32 to vector<1x1x1x1x64xf32>
    %21 = memref.load %2[%18, %7, %c0] : memref<11008x32x1xf32>
    %22 = vector.broadcast %21 : f32 to vector<1x1x1x1x64xf32>
    %23 = arith.extui %17 : vector<1x1x64xi4> to vector<1x1x64xi32>
    %24 = arith.uitofp %23 : vector<1x1x64xi32> to vector<1x1x64xf32>
    %25 = vector.broadcast %24 : vector<1x1x64xf32> to vector<1x1x1x1x64xf32>
    %26 = arith.subf %25, %22 : vector<1x1x1x1x64xf32>
    %27 = arith.mulf %26, %20 : vector<1x1x1x1x64xf32>
    %28 = vector.extract %11[0, 0, 0] : vector<1x1x1xf32>
    %29 = vector.extract %27[0, 0, 0, 0] : vector<1x1x1x1x64xf32>
    %30 = arith.mulf %13, %29 : vector<64xf32>
    %31 = vector.reduction <add>, %30, %28 : vector<64xf32> into f32
    %32 = vector.insert %31, %cst [0] : f32 into vector<1xf32>
    %33 = vector.broadcast %32 : vector<1xf32> to vector<1x1x1xf32>
    %34 = arith.addi %10, %c64 : index
    cf.br ^bb3(%34, %33 : index, vector<1x1x1xf32>)
  ^bb5:  // pred: ^bb3
    %35 = arith.addi %7, %c1 : index
    cf.br ^bb2(%35, %11 : index, vector<1x1x1xf32>)
  ^bb6:  // pred: ^bb2
    %36 = vector.extract %8[0, 0, 0] : vector<1x1x1xf32>
    %37 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 32)>()[%5, %workgroup_id_x]
    memref.store %36, %4[%c0, %c0, %37] : memref<1x1x11008xf32>
    %38 = arith.addi %5, %c1 : index
    cf.br ^bb1(%38 : index)
  ^bb7:  // pred: ^bb1
    return
  }
}
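The affine.apply in ^bb4 computes a byte offset into the packed buffer. Each 128-element i4 row occupies 64 bytes, and each 32x128 slice occupies 2048 bytes. The sketch below checks that the map (with s0 = %7, s1 = %10, s2 = %5, s3 = %workgroup_id_x, and 32 rows per workgroup as in the ^bb1 loop) matches the row-major i4 element index of an 11008x32x128 tensor halved, and that the buffer size is 22544384 bytes. Both helper names are hypothetical.

```python
def i4_byte_offset(j, l, row_in_wg, wg_id):
    # affine_map<()[s0, s1, s2, s3] -> (s0 * 64 + s2 * 2048 + s3 * 65536 + s1 floordiv 2)>
    return j * 64 + row_in_wg * 2048 + wg_id * 65536 + l // 2


def i4_byte_offset_from_index(n, j, l, groups=32, group_size=128):
    # Row-major element index into N x 32 x 128 i4, divided by 2 (two i4 per byte)
    return (n * groups * group_size + j * group_size + l) // 2
```

The two agree because n * 4096 + j * 128 is always even, so halving the sum only rounds the l term.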

It loads a half-sized vector of i8 (vector<32xi8>) and then uses vector.bitcast to reinterpret it as vector<64xi4> without moving any data. However, each i4 eventually does have to be moved into its own byte, so this is probably the place where we should try to fix the issue.

One idea would be to use two vector.maskedload ops here instead of

%15 = vector.load %0[%14] : memref<22544384xi8>, vector<32xi8>

and then, rather than doing a vector.bitcast, keep the vector in i8 so that the arith.extui simply changes from i4->i32 to i8->i32.

Does this seem like a reasonable approach?

Max191 commented 1 year ago

I made a draft PR for a pass to rewrite the above bitcast sequence: https://github.com/openxla/iree/pull/14921

It still needs to be generalized further, but I'm posting it here for discussion. It converts the sequence:

vector.load->vector.bitcast->vector.broadcast->arith.extui

into

vector.load->vector.bitcast->vector.shuffle->arith.andi->arith.shrui->vector.broadcast
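A per-lane Python model of what a shuffle/and/shift-right sequence like this can compute is shown below. It assumes the shuffle duplicates each loaded byte into two adjacent lanes, the and-mask alternates 0x0F/0xF0, and the shift amounts alternate 0/4, with the low nibble first. These masks and this lane order are illustrative assumptions, not taken from the PR, which may choose them differently. The function name is hypothetical.

```python
def unpack_i4_shuffle(packed: bytes) -> list:
    lanes = 2 * len(packed)
    # vector.shuffle: lane 2*i and lane 2*i+1 both receive byte i
    shuffle_mask = [lane // 2 for lane in range(lanes)]
    dup = [packed[idx] for idx in shuffle_mask]
    # arith.andi: keep the low nibble in even lanes and the high nibble in odd lanes
    and_mask = [0x0F if lane % 2 == 0 else 0xF0 for lane in range(lanes)]
    masked = [v & m for v, m in zip(dup, and_mask)]
    # arith.shrui: shift odd lanes right by 4 so every value lands in bits 0..3
    shift = [0 if lane % 2 == 0 else 4 for lane in range(lanes)]
    return [v >> s for v, s in zip(masked, shift)]
```

Every step here is a uniform lane-wise operation with constant masks, which is what lets the backend emit vector instructions instead of the scalarized per-byte unpacking. For example, `unpack_i4_shuffle(bytes([0x21, 0xF0]))` returns `[1, 2, 0, 15]`.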

It brings llama2 CPU latency down from 420 ms to 175 ms on my machine (both numbers include the custom tuning config mentioned above).