iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

[CPU][DT] i8.i8.i32 mmt4d codegen kernel is 10x slower than ukernels #15611

Status: Open. hanhanW opened this issue 10 months ago.

hanhanW commented 10 months ago

This is mainly because we are not using VNNI in codegen. The test case is extracted from the MobileBert_int8 model.

To repro,

  1. Compile the mmt4d kernel with codegen by running iree-compile --output-format=vm-bytecode --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu=cascadelake --iree-llvmcpu-target-triple=x86_64-unknown-linux-gnu ~/repro.mlir -o /tmp/a.vmfb
  2. To compile the mmt4d kernel with ukernels instead, add --iree-llvmcpu-enable-ukernels=all to the iree-compile command.
  3. Run iree-benchmark-module --device=local-sync --module=/tmp/a.vmfb --function=foo --input=24x256x16x2xi8 --input=8x256x16x2xi8

Contents of ~/repro.mlir:

```mlir
func.func @foo(%arg0: tensor<24x256x16x2xi8>, %arg1: tensor<8x256x16x2xi8>) -> tensor<24x8x16x16xi32> {
  %c0_i32 = arith.constant 0 : i32
  %0 = tensor.empty() : tensor<24x8x16x16xi32>
  %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<24x8x16x16xi32>) -> tensor<24x8x16x16xi32>
  %2 = linalg.mmt4d ins(%arg0, %arg1 : tensor<24x256x16x2xi8>, tensor<8x256x16x2xi8>) outs(%1 : tensor<24x8x16x16xi32>) -> tensor<24x8x16x16xi32>
  return %2 : tensor<24x8x16x16xi32>
}
```
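
For reference, linalg.mmt4d is a matmul over pre-packed 4-D tiles in which each RHS tile is stored transposed (N0 x K0). Below is a minimal scalar C sketch of the i8.i8.i32 case, assuming dense row-major buffers; the names mmt4d_s8s8s32_ref and idx4 are illustrative only, not IREE's ukernel API.

```c
#include <stddef.h>
#include <stdint.h>

/* Row-major offset into a 4-D buffer of shape d0 x d1 x d2 x d3. */
static size_t idx4(size_t i0, size_t i1, size_t i2, size_t i3,
                   size_t d1, size_t d2, size_t d3) {
  return ((i0 * d1 + i1) * d2 + i2) * d3 + i3;
}

/* Scalar reference for linalg.mmt4d with i8 inputs and i32 accumulation.
 * lhs: M1 x K1 x M0 x K0, rhs: N1 x K1 x N0 x K0, out: M1 x N1 x M0 x N0.
 * Accumulates into out, so zero it first (the linalg.fill in @foo).
 * Each i8 is sign-extended to i32 before the multiply. */
void mmt4d_s8s8s32_ref(const int8_t *lhs, const int8_t *rhs, int32_t *out,
                       size_t M1, size_t N1, size_t K1,
                       size_t M0, size_t N0, size_t K0) {
  for (size_t m1 = 0; m1 < M1; ++m1)
    for (size_t n1 = 0; n1 < N1; ++n1)
      for (size_t m0 = 0; m0 < M0; ++m0)
        for (size_t n0 = 0; n0 < N0; ++n0) {
          int32_t acc = out[idx4(m1, n1, m0, n0, N1, M0, N0)];
          for (size_t k1 = 0; k1 < K1; ++k1)
            for (size_t k0 = 0; k0 < K0; ++k0)
              acc += (int32_t)lhs[idx4(m1, k1, m0, k0, K1, M0, K0)] *
                     (int32_t)rhs[idx4(n1, k1, n0, k0, K1, N0, K0)];
          out[idx4(m1, n1, m0, n0, N1, M0, N0)] = acc;
        }
}
```

For @foo this would be called with M1 = 24, N1 = 8, K1 = 256, M0 = N0 = 16, K0 = 2 on a zero-initialized output. The explicit (int32_t) casts correspond to the mul (sext i8 to i32) form discussed later in the thread.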

With codegen, the kernel takes 3 ms on my n2-standard-32 VM, while the ukernel takes 0.311 ms.
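
To put the two timings in perspective, the amount of work follows directly from the shapes in @foo: 24 x 16 = 384 output rows, 8 x 16 = 128 output columns, and a reduction length of 256 x 2 = 512. The small sketch below does that arithmetic; counting one multiply-accumulate as two ops is a common convention assumed here, not something stated in the issue.

```c
#include <stdint.h>

/* Multiply-accumulates in an mmt4d: (M1*M0) x (N1*N0) outputs, each a dot
 * product of length K1*K0. */
uint64_t mmt4d_macs(uint64_t M1, uint64_t N1, uint64_t K1,
                    uint64_t M0, uint64_t N0, uint64_t K0) {
  return (M1 * M0) * (N1 * N0) * (K1 * K0);
}

/* Throughput in GOP/s for a run of `millis` milliseconds, counting each
 * MAC as two ops (one multiply, one add). */
double gops(uint64_t macs, double millis) {
  return 2.0 * (double)macs / (millis * 1e-3) / 1e9;
}
```

For @foo this gives 25,165,824 MACs. With the reported timings, that is roughly 16.8 GOP/s for codegen versus about 162 GOP/s for the ukernel, a ratio of about 9.6x.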

dcaballe commented 10 months ago

Thanks, Hanhan! @qcolombet, this is the example I mentioned about VNNI.

qcolombet commented 10 months ago

Looking.

FWIW, my desktop is not a Cascade Lake machine, so running the VNNI ukernels is not possible (Illegal Instruction). With --iree-llvmcpu-target-cpu=host I get:

qcolombet commented 10 months ago

TL;DR: At first glance, I don't think this is going to be an easy fix. I'll have to look more closely at the IREE pipeline to see if we can expose the relevant intrinsic while the widening semantics are still easily accessible in the IR.

I believe that the code that reaches the backend carries semantics that are too hard to untangle to extract the relevant pattern. Essentially, what reaches the backend tells it to produce correct i32 operations from the i8 elements (e.g., mul (sext i8 to i32)). The complicated part is that the IR cannot express what we need to easily expose the relevant semantics to the backend. In a nutshell, we are missing the widening operation (e.g., i32 = mul i16, i16). So currently, the backend would need to rebuild that information from potentially complicated IR.

Note: The pattern we want here is widening followed by reduction (the reduction part being i32 horizontal_add <2xi32>, <2xi32>). The reduction is already available through generic constructs (see llvm.vector.reduce.add).
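
Spelled out as separate steps, the pattern is: widen the elements, multiply them element-wise in i32, then add adjacent pairs of i32 products and accumulate. The scalar C sketch below shows that decomposition, assuming the K0 = 2 inner dimension from @foo so that each output lane consumes exactly one adjacent pair; the function name is hypothetical.

```c
#include <stddef.h>
#include <stdint.h>

/* "Widening followed by reduction" as two explicit steps.
 * a and b hold 2*n i8 elements; acc holds n i32 lanes.
 * Lane i consumes elements 2*i and 2*i+1, i.e. one <2 x i32> pair. */
void widen_mul_pairwise_add(const int8_t *a, const int8_t *b, int32_t *acc,
                            size_t n) {
  for (size_t i = 0; i < n; ++i) {
    /* Step 1: widening multiply; an i8 x i8 product always fits in i32. */
    int32_t p0 = (int32_t)a[2 * i] * (int32_t)b[2 * i];
    int32_t p1 = (int32_t)a[2 * i + 1] * (int32_t)b[2 * i + 1];
    /* Step 2: reduce the pair of products and accumulate into the lane. */
    acc[i] += p0 + p1;
  }
}
```

The point of the note is that step 2 already has a generic IR form, while step 1, as a true widening op, does not.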

The ukernel gets away with this representation gap because it hides the real type (bitcast <32 x i16> to <16 x i32>) and then uses the intrinsic that does both the element-wise widening mul and the reduction in one go (in this case vpdpwssd). Without using the intrinsic early in the IREE pipeline (i.e., while we still know these are the semantics we want), we can't fake the type, because the regular mul operation wouldn't have the semantics we want.
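
Below is a scalar model of one 512-bit vpdpwssd, following Intel's description of the instruction: each of the 16 i32 lanes adds the two products of its pair of signed i16 elements to the accumulator, with wraparound (vpdpwssds is the saturating variant). The second function shows one plausible way to map a 16x16x2 mmt4d tile onto it. The i8 inputs are sign-extended to i16 first, because vpdpbusd requires one unsigned operand. This mapping is an assumption made for illustration rather than a reading of the IREE ukernel source, and both function names are hypothetical.

```c
#include <stdint.h>

/* Scalar model of a 512-bit vpdpwssd: 16 i32 lanes, each consuming one pair
 * of adjacent i16 elements from a and from b. */
void vpdpwssd_model(int32_t *acc /*16*/, const int16_t *a /*32*/,
                    const int16_t *b /*32*/) {
  for (int i = 0; i < 16; ++i) {
    int32_t p0 = (int32_t)a[2 * i] * (int32_t)b[2 * i]; /* fits in i32 */
    int32_t p1 = (int32_t)a[2 * i + 1] * (int32_t)b[2 * i + 1];
    /* Non-saturating: the sum wraps modulo 2^32 (two's complement on
     * mainstream compilers), like the instruction. */
    acc[i] = (int32_t)((uint32_t)acc[i] + (uint32_t)p0 + (uint32_t)p1);
  }
}

/* One k1 step of a 16x16x2 (M0 x N0 x K0) mmt4d tile, row-major buffers:
 * lhs is 16x2 i8, rhs is 16x2 i8, out is 16x16 i32. The RHS tile's
 * [n0][k0] layout is already one i16 pair per i32 lane, so each LHS row
 * needs a broadcast of its (k0=0, k0=1) pair plus one vpdpwssd. */
void mmt4d_tile_step(int32_t *out, const int8_t *lhs, const int8_t *rhs) {
  int16_t b[32];
  for (int j = 0; j < 32; ++j) b[j] = rhs[j]; /* sign-extend i8 -> i16 */
  for (int m0 = 0; m0 < 16; ++m0) {
    int16_t a[32];
    for (int i = 0; i < 16; ++i) { /* broadcast the LHS pair to all lanes */
      a[2 * i] = lhs[2 * m0];
      a[2 * i + 1] = lhs[2 * m0 + 1];
    }
    vpdpwssd_model(out + 16 * m0, a, b);
  }
}
```

Under this mapping, one k1 step of the tile costs 16 fused instructions, whereas the generic IR expresses the same work as separate sext, mul, and add operations that the backend would have to reassemble.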

qcolombet commented 10 months ago

I looked a little more at the backend, and there is a (convoluted) way of catching these kinds of VNNI instructions. The trick is to express the computations on the odd/even lanes separately and then zip them back together.

The actual pattern to produce is fairly hairy and probably brittle. At this point it is probably better to just emit the related intrinsic than producing this complicated sequence that may or may not be lowered to what we want.

Here is what we match for PMADDUBSW (look for detectPMADDUBSW in https://raw.githubusercontent.com/llvm/llvm-project/main/llvm/lib/Target/X86/X86ISelLowering.cpp):

```
// (i16 (ssat (add (mul (zext (even elts (i8 A))), (sext (even elts (i8 B)))),
//                 (mul (zext (odd elts (i8 A))), (sext (odd elts (i8 B)))))))
```

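
Read as scalar code, the matched pattern says: take the even and odd byte elements separately, zero-extend A (unsigned) and sign-extend B (signed), multiply each half, add the two halves lane by lane, and saturate to i16. Below is a C sketch of that even/odd formulation for one 128-bit pmaddubsw (16 bytes in, 8 i16 lanes out). The function names are hypothetical, and the sketch models the arithmetic only, not LLVM's matcher.

```c
#include <stdint.h>

/* Signed saturation of an i32 value to the i16 range (the "ssat"). */
static int16_t ssat16(int32_t x) {
  if (x > INT16_MAX) return INT16_MAX;
  if (x < INT16_MIN) return INT16_MIN;
  return (int16_t)x;
}

/* pmaddubsw in the shape the DAG pattern describes: even and odd lanes are
 * computed as two separate widening multiplies (zext A * sext B), then
 * added lane by lane and saturated. */
void pmaddubsw_model(int16_t *out /*8*/, const uint8_t *a /*16*/,
                     const int8_t *b /*16*/) {
  int32_t even[8], odd[8];
  for (int i = 0; i < 8; ++i) {
    even[i] = (int32_t)a[2 * i] * (int32_t)b[2 * i];
    odd[i] = (int32_t)a[2 * i + 1] * (int32_t)b[2 * i + 1];
  }
  for (int i = 0; i < 8; ++i)
    out[i] = ssat16(even[i] + odd[i]);
}
```

The intermediate saturation is why this form is hairy to target from generic IR: the adds must be expressed on i16 with ssat, and the even/odd shuffles must line up exactly for the matcher to fire.
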
hanhanW commented 10 months ago

> At this point it is probably better to just emit the related intrinsic than producing this complicated sequence that may or may not be lowered to what we want.

We can do it through VectorContractCustomKernelsPass. We are able to use inline assembly or intrinsics for some vector.contract ops. With some matchers, we should be able to generate VNNI instructions.

dcaballe commented 10 months ago

Thanks for looking into this, Quentin! It's good to know that at least we have some patterns...

> I believe that the code that reaches the backend carries semantics that are too hard to untangle to extract the relevant pattern. Essentially, what reaches the backend tells it to produce correct i32 operations from the i8 elements (e.g., mul (sext i8 to i32)). The complicated part is that the IR cannot express what we need to easily expose the relevant semantics to the backend.

Up to here, the problem seems similar to the RVV widening instructions. Interestingly enough, they only support 2x widening, so something like i8 -> i32 is not directly supported. However, the backend is able to match the 4x cases and generate the corresponding sext + widening instructions.
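
The 4x case can be sketched as two 2x steps. First, sign-extend i8 to i16 (the per-element effect of vsext.vf2). Second, perform a 2x widening multiply i16 x i16 -> i32 (the per-element effect of vwmul.vv). The scalar C sketch below uses hypothetical helper names and checks that this agrees with the direct mul (sext i8 to i32) form. It models the arithmetic only, not how the RVV backend actually performs the match.

```c
#include <stdint.h>

/* 2x sign extension, i8 -> i16 (what vsext.vf2 does per element). */
static int16_t sext_i8_to_i16(int8_t x) { return (int16_t)x; }

/* 2x widening multiply, i16 x i16 -> i32 (what vwmul.vv does per element). */
static int32_t widening_mul_i16(int16_t a, int16_t b) {
  return (int32_t)a * (int32_t)b;
}

/* 4x widening multiply i8 x i8 -> i32, built from the two 2x steps. */
int32_t mul_i8_to_i32_via_2x(int8_t a, int8_t b) {
  return widening_mul_i16(sext_i8_to_i16(a), sext_i8_to_i16(b));
}

/* The form that reaches the backend: mul (sext i8 to i32). */
int32_t mul_sext_i8_to_i32(int8_t a, int8_t b) {
  return (int32_t)a * (int32_t)b;
}
```

Because every i8 value fits in i16, the two forms agree for all inputs. That equivalence is what the backend has to rediscover before it can select a widening instruction.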

> The trick is to express the computations on the odd/even lanes separately and then zip them back together.

Yeah, this sounds tricky...

> At this point it is probably better to just emit the related intrinsic than producing this complicated sequence that may or may not be lowered to what we want.

> We can do it through VectorContractCustomKernelsPass. We are able to use inline assembly or intrinsics for some vector.contract ops. With some matchers, we should be able to generate VNNI instructions.

I think to handle this properly we may want to introduce a vnni dialect and implement a simple lowering to the few VNNI ops of interest. We just need to find cycles for it :)

qcolombet commented 9 months ago

> At this point it is probably better to just emit the related intrinsic than producing this complicated sequence that may or may not be lowered to what we want.

> We can do it through VectorContractCustomKernelsPass. We are able to use inline assembly or intrinsics for some vector.contract ops. With some matchers, we should be able to generate VNNI instructions.

I don't see that pass being used, but maybe I'm missing something.

What I see is that populateVectorContractCustomKernelsPatterns is used in LLVMCPUMmt4dVectorLoweringPass, so that could be a way to get there (i.e., same code base, but a different entry point). We may be able to lower vector.contract to a VNNI-like dialect or to intrinsics there instead of matching an assembly kernel.

qcolombet commented 9 months ago

After looking at LLVMCPUMmt4dVectorLoweringPass a little, I believe we could make it work, but I'm wondering whether the lowering to vector.contract even makes sense. I feel it would be easier to massage the loops (tiling) to expose linalg.matmul ops with sizes that map well to the hardware. Admittedly, I am new to the whole IREE pipeline, so let me know if that doesn't make sense or goes against the lowering currently done in IREE.