iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/

[CPU] Performance for quantized matvec 13x worse with fp16 than fp32 #14864

Open Max191 opened 1 year ago

Max191 commented 1 year ago

What happened?

The benchmarks of the following IRs are much slower for fp16 than fp32:

fp16 IR:

builtin.module {
  func.func @quantized_matmul_f16(%arg0: tensor<11008x32x128xi8>, %arg1: tensor<11008x32x1xf16>, %arg2: tensor<11008x32x1xf16>, %arg3: tensor<1x1x32x128xf16>) -> tensor<1x1x11008xf16> {
    %cst = arith.constant 0.000000e+00 : f16
    %4 = tensor.empty() : tensor<1x1x11008xf16>
    %5 = tensor.empty() : tensor<11008x32x128xf16>
    %6 = linalg.fill ins(%cst : f16) outs(%4 : tensor<1x1x11008xf16>) -> tensor<1x1x11008xf16>
    %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg1, %arg2 : tensor<11008x32x128xi8>, tensor<11008x32x1xf16>, tensor<11008x32x1xf16>) outs(%5 : tensor<11008x32x128xf16>) {
    ^bb0(%in: i8, %in_0: f16, %in_1: f16, %out: f16):
      %9 = arith.extui %in : i8 to i32
      %10 = arith.uitofp %9 : i32 to f16
      %11 = arith.subf %10, %in_1 : f16
      %12 = arith.mulf %11, %in_0 : f16
      linalg.yield %12 : f16
    } -> tensor<11008x32x128xf16>
    %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg3, %7 : tensor<1x1x32x128xf16>, tensor<11008x32x128xf16>) outs(%6 : tensor<1x1x11008xf16>) {
    ^bb0(%in: f16, %in_0: f16, %out: f16):
      %9 = arith.mulf %in, %in_0 : f16
      %10 = arith.addf %9, %out : f16
      linalg.yield %10 : f16
    } -> tensor<1x1x11008xf16>
    return %8 : tensor<1x1x11008xf16>
  }
}

Benchmark:

Running iree-benchmark-module
Run on (24 X 5732.71 MHz CPU s)
CPU Caches:
  L1 Data 32 KiB (x12)
  L1 Instruction 32 KiB (x12)
  L2 Unified 1024 KiB (x12)
  L3 Unified 32768 KiB (x2)
Load Average: 0.62, 1.16, 0.62
***WARNING*** CPU scaling is enabled, the benchmark real time measurements may be noisy and will incur extra overhead.
***WARNING*** Library was built as DEBUG. Timings may be affected.
---------------------------------------------------------------------------------------------------------
Benchmark                                               Time             CPU   Iterations UserCounters...
---------------------------------------------------------------------------------------------------------
BM_quantized_matmul_f16/process_time/real_time       7.91 ms         42.7 ms           89 items_per_second=126.375/s

fp32 IR:

builtin.module {
  func.func @quantized_matmul_f32(%arg0: tensor<11008x32x128xi8>, %arg1: tensor<11008x32x1xf32>, %arg2: tensor<11008x32x1xf32>, %arg3: tensor<1x1x32x128xf32>) -> tensor<1x1x11008xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %4 = tensor.empty() : tensor<1x1x11008xf32>
    %5 = tensor.empty() : tensor<11008x32x128xf32>
    %6 = linalg.fill ins(%cst : f32) outs(%4 : tensor<1x1x11008xf32>) -> tensor<1x1x11008xf32>
    %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg1, %arg2 : tensor<11008x32x128xi8>, tensor<11008x32x1xf32>, tensor<11008x32x1xf32>) outs(%5 : tensor<11008x32x128xf32>) {
    ^bb0(%in: i8, %in_0: f32, %in_1: f32, %out: f32):
      %9 = arith.extui %in : i8 to i32
      %10 = arith.uitofp %9 : i32 to f32
      %11 = arith.subf %10, %in_1 : f32
      %12 = arith.mulf %11, %in_0 : f32
      linalg.yield %12 : f32
    } -> tensor<11008x32x128xf32>
    %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg3, %7 : tensor<1x1x32x128xf32>, tensor<11008x32x128xf32>) outs(%6 : tensor<1x1x11008xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %9 = arith.mulf %in, %in_0 : f32
      %10 = arith.addf %9, %out : f32
      linalg.yield %10 : f32
    } -> tensor<1x1x11008xf32>
    return %8 : tensor<1x1x11008xf32>
  }
}

Benchmark:

Running iree-benchmark-module
Run on (24 X 5732.71 MHz CPU s)
CPU Caches:
  L1 Data 32 KiB (x12)
  L1 Instruction 32 KiB (x12)
  L2 Unified 1024 KiB (x12)
  L3 Unified 32768 KiB (x2)
Load Average: 0.96, 1.22, 0.64
***WARNING*** CPU scaling is enabled, the benchmark real time measurements may be noisy and will incur extra overhead.
***WARNING*** Library was built as DEBUG. Timings may be affected.
---------------------------------------------------------------------------------------------------------
Benchmark                                               Time             CPU   Iterations UserCounters...
---------------------------------------------------------------------------------------------------------
BM_quantized_matmul_f32/process_time/real_time      0.604 ms         3.19 ms          997 items_per_second=1.65679k/s

Steps to reproduce your issue

  1. On commit 40794933d45fdbb05d631c9612dc91cc343d1efe
  2. Compile command:
    iree-compile --iree-input-type=tm_tensor --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu-features=host --iree-llvmcpu-target-triple=x86_64-linux-gnu --iree-llvmcpu-enable-microkernels --iree-llvmcpu-stack-allocation-limit=256000 --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-vm-bytecode-module-strip-source-map=true --iree-util-zero-fill-elided-attrs --iree-vm-target-truncate-unsupported-floats --iree-codegen-check-ir-before-llvm-conversion=false --iree-opt-const-expr-hoisting=False quantized_matmul_f32.mlir -o quantized_matmul_f32.vmfb
  3. Run command:
    iree-benchmark-module --module=quantized_matmul_f32.vmfb --device=local-task --function=quantized_matmul_f32 --input=11008x32x128xi8 --input=11008x32x1xf32 --input=11008x32x1xf32 --input=1x1x32x128xf32

What component(s) does this issue relate to?

No response

Version information

No response

Additional context

No response

Max191 commented 1 year ago

I am now realizing that I do not have avx512fp16 in my iree-cpuinfo. Could this cause such a large performance difference?

sse3                 1
ssse3                1
sse4.1               1
sse4.2               1
sse4a                1
avx                  1
fma                  1
fma4                 0
xop                  0
f16c                 1
avx2                 1
avx512f              1
avx512cd             1
avx512vl             1
avx512dq             1
avx512bw             1
avx512ifma           1
avx512vbmi           1
avx512vpopcntdq      1
avx512vnni           1
avx512vbmi2          1
avx512bitalg         1
avx512bf16           1
avx512fp16           0
amx-tile             0
amx-int8             0
amx-bf16             0

It seems to me from the IR dumps that the two IRs are basically the same other than fp16 vs fp32, even through the LLVM codegen.

bjacob commented 1 year ago

I am now realizing that I do not have avx512fp16 in my iree-cpuinfo. Could this cause such a large performance difference?

Indeed, like almost every x86 CPU in existence, this one does not have native f16 arithmetic support, meaning that f16 arithmetic has to be implemented by casting from f16 to f32 and doing the arithmetic in f32. Like you found, the x86 extension that would enable native f16 arithmetic is avx512fp16, and that is only supported in a few very recent Intel microarchitectures, and in no AMD microarchitecture at all at the moment. Per https://en.wikipedia.org/wiki/AVX-512#CPUs_with_AVX-512 (check out the big table and scroll right to see FP16 in the very last column), it is only available in Intel Sapphire Rapids (2023).

It's worth noting that the equivalent feature for the bfloat16 type, avx512bf16, is a bit more widely available. In particular, it is available in AMD Zen4. It would not be hard to add some microkernels targeting that, I've been meaning to do so.

Also, on the arm64 architecture, support for f16 and bf16 arithmetic is more widespread.

Anyway, what the above says is that, unless targeting Sapphire Rapids or some non-x86, it is absolutely expected that f16 performs worse than f32.

But it should only be performing slightly worse, because the overhead of converting between f16 and f32 should be small: all current x86 CPUs support the f16c extension (see your iree-cpuinfo log above).

So, my explanation above would account for f16 performing somewhere between 50% and 90% of f32. It does not explain f16 performing 13x slower than f32.

Maybe rename this issue to make it clear that what is noteworthy here is not that f16 is worse than f32, but the large factor by which it is.

bjacob commented 1 year ago

Next steps here:

The payload consists of two linalg.generic ops: the first one dequantizes i8 to float, the second one is the matvec. The immediate next step is to split that up and study both parts separately.

Is the 13x speed difference concentrated in one of the two linalg.generic ops?

For each part, compile with --iree-llvm-keep-linker-artifacts and --iree-llvm-link-embedded=false and get a disassembly of the resulting .so file using objdump -d -Mintel.

From past experience, it's expected that the codegen for the f16 matvec is sub-optimal, even when the matvec is written as a linalg.matmul; that's one of the reasons why I added f16 matmul ukernels. But for those to help here, two things need to happen:

bjacob commented 1 year ago

Also, since the testcase here involves i8->f16 dequantization, taking a step back, the real optimal outcome here would be to just avoid dequantization. Staying on the i8 path will result in higher performance than any float path on most CPUs.
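
To make that concrete, here is a minimal sketch (not from this issue; the function name, the collapsed 11008x4096 shape, and the signed-extension choice are illustrative assumptions) of an integer-only matvec where i8 products are accumulated in i32. The per-group scales and zero points from the testcase are omitted; in a real lowering they would be applied to per-group i32 partial sums rather than to every weight element:

    func.func @int8_matvec(%w: tensor<11008x4096xi8>, %v: tensor<4096xi8>) -> tensor<11008xi32> {
      %c0_i32 = arith.constant 0 : i32
      %init = tensor.empty() : tensor<11008xi32>
      %fill = linalg.fill ins(%c0_i32 : i32) outs(%init : tensor<11008xi32>) -> tensor<11008xi32>
      // i8 x i8 -> i32 multiply-accumulate: the flavor of integer dot product that
      // extensions like avx512vnni accelerate.
      %res = linalg.generic {
          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                           affine_map<(d0, d1) -> (d1)>,
                           affine_map<(d0, d1) -> (d0)>],
          iterator_types = ["parallel", "reduction"]}
          ins(%w, %v : tensor<11008x4096xi8>, tensor<4096xi8>)
          outs(%fill : tensor<11008xi32>) {
      ^bb0(%in: i8, %in_0: i8, %out: i32):
        %0 = arith.extsi %in : i8 to i32
        %1 = arith.extsi %in_0 : i8 to i32
        %2 = arith.muli %0, %1 : i32
        %3 = arith.addi %2, %out : i32
        linalg.yield %3 : i32
      } -> tensor<11008xi32>
      return %res : tensor<11008xi32>
    }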

Max191 commented 1 year ago

One of the issues here is that the dequantization linalg.generic and the matvec linalg.generic get fused and then rematerialized into a single linalg.generic:

// -----// IR Dump After RematerializeParallelOps (iree-codegen-rematerialize-parallel-ops) //----- //
func.func @quantized_matmul_f16_dispatch_0_generic_1x1x11008x32x128_f16() {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f16
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<11008x32x128xi8>>
  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<11008x32x1xf16>>
  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<11008x32x1xf16>>
  %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1x1x32x128xf16>>
  %4 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<1x1x11008xf16>>
  %5 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [11008, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<11008x32x128xi8>> -> tensor<11008x32x128xi8>
  %6 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [11008, 32, 1], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<11008x32x1xf16>> -> tensor<11008x32x1xf16>
  %7 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [11008, 32, 1], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<11008x32x1xf16>> -> tensor<11008x32x1xf16>
  %8 = flow.dispatch.tensor.load %3, offsets = [0, 0, 0, 0], sizes = [1, 1, 32, 128], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x1x32x128xf16>> -> tensor<1x1x32x128xf16>
  %9 = tensor.empty() : tensor<1x1x11008xf16>
  %10 = linalg.fill ins(%cst : f16) outs(%9 : tensor<1x1x11008xf16>) -> tensor<1x1x11008xf16>
  %11 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, 0)>, affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, 0)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%8, %5, %6, %7 : tensor<1x1x32x128xf16>, tensor<11008x32x128xi8>, tensor<11008x32x1xf16>, tensor<11008x32x1xf16>) outs(%10 : tensor<1x1x11008xf16>) {
  ^bb0(%in: f16, %in_0: i8, %in_1: f16, %in_2: f16, %out: f16):
    %12 = arith.extui %in_0 : i8 to i16
    %13 = arith.uitofp %12 : i16 to f16
    %14 = arith.subf %13, %in_2 : f16
    %15 = arith.mulf %14, %in_1 : f16
    %16 = arith.mulf %in, %15 : f16
    %17 = arith.addf %16, %out : f16
    linalg.yield %17 : f16
  } -> tensor<1x1x11008xf16>
  flow.dispatch.tensor.store %11, %4, offsets = [0, 0, 0], sizes = [1, 1, 11008], strides = [1, 1, 1] : tensor<1x1x11008xf16> -> !flow.dispatch.tensor<writeonly:tensor<1x1x11008xf16>>
  return
}

We do this because we want these linalg.generic ops to end up in the same dispatch, and without rematerialization we were restricted to small tile sizes on the dequantization, which significantly hurt performance. So, I don't think raising the matvec will be possible. I can still try splitting the IR and seeing which linalg.generic sees most of the difference. That may indicate what part is causing problems in the rematerialized op.

Another thing I started a while ago, before being pulled away to SPIR-V work, was a new ukernel for this specific rematerialized dequantization + matvec. I hadn't made much progress on that, but perhaps we could go back to adding that ukernel now.

Max191 commented 1 year ago

Also, since the testcase here involves i8->f16 dequantization, taking a step back, the real optimal outcome here would be to just avoid dequantization. Staying on the i8 path will result in higher performance than any float path on most CPUs.

I am actually only running in i8 at the moment so I can give the quantized weights as function arguments and actually benchmark in iree-benchmark-module, but eventually we will want to be using i4. Although, the point still stands to extend to i8 rather than f16, as per the discussion on i4 correctness.

bjacob commented 1 year ago

I am actually only running in i8 at the moment so I can give the quantized weights as function arguments and actually benchmark in iree-benchmark-module, but eventually we will want to be using i4. Although, the point still stands to extend to i8 rather than f16, as per the discussion on i4 correctness.

Thanks for confirming --- I was suspecting that this was all part of the same overarching project, but didn't have context.

Yes. The trouble you're running into here is more evidence that we need to handle i4 by extending to i8, avoiding dequantization.

What we are looking at here is feasible, but not small enough to be the kind of thing we do as an interim solution while we wait to switch away from dequantization. Adding a custom ukernel for a fusion is not really something we'd do on a short-term interim timeframe, and the alternative, looking into a SIMD codegen performance issue for a specific fusion, is probably harder still. Both are things we'd like to only do for workloads that are here to stay, not interim ones.

I do understand and agree though that to the extent that this is the workload that we have to optimize, the fusion does make sense and is needed to unlock optimal performance. This is because the dequantization here is actually on the matrix side, not on the vector side, of the matvec, so the output of the dequantization arithmetic is traversed exactly once (this would be different if we were dequantizing the vector, or if we were dealing with a more general matmul, traversing its inputs multiple times, instead of a matvec). But then, like you note, we are either embarking on writing a new ukernel for this fusion, or foregoing the ability to use ukernels and have to fix the SIMD codegen in the compiler; neither is to be undertaken lightly :-)

If you want a pragmatic approach to unblock improving performance here without this becoming a large project, you could forego the fusion. While the fusion makes sense, it turns this into a large project, so settle for something easier to achieve here, essentially what I outlined in my comment above: rewrite the second linalg.generic as a linalg.matmul, pass --iree-flow-enable-data-tiling, make sure you get a linalg.mmt4d op, make sure it's not fused and is matched as an mmt4d ukernel, then check/work with me regarding adding narrow/matvec ukernel code paths.
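
To make the un-fused variant concrete, here is a minimal sketch of the second op rewritten as a named op on already-dequantized f16 weights (shapes are taken from the testcase; the function name and the use of tensor.collapse_shape + linalg.matvec rather than a literal linalg.matmul are my assumptions, the idea being the same):

    func.func @matvec_as_named_op(%w_dequant: tensor<11008x32x128xf16>,
                                  %vec: tensor<1x1x32x128xf16>) -> tensor<11008xf16> {
      %cst = arith.constant 0.000000e+00 : f16
      // Collapse the two reduction dims (32x128) into a single K = 4096 dim.
      %w = tensor.collapse_shape %w_dequant [[0], [1, 2]]
          : tensor<11008x32x128xf16> into tensor<11008x4096xf16>
      %v = tensor.collapse_shape %vec [[0, 1, 2, 3]]
          : tensor<1x1x32x128xf16> into tensor<4096xf16>
      %init = tensor.empty() : tensor<11008xf16>
      %fill = linalg.fill ins(%cst : f16) outs(%init : tensor<11008xf16>) -> tensor<11008xf16>
      // Named-op form of the matvec; the suggestion above uses the analogous
      // linalg.matmul so that data tiling can produce linalg.mmt4d.
      %res = linalg.matvec ins(%w, %v : tensor<11008x4096xf16>, tensor<4096xf16>)
                           outs(%fill : tensor<11008xf16>) -> tensor<11008xf16>
      return %res : tensor<11008xf16>
    }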

stellaraccident commented 1 year ago

FTR -- I think we're ok taking on an ambitious project if we believe it is the path to getting the best performance.

What isn't clear to me in this case is how the quantization would need to change or be specified to get us into a pure int world.

One counterpoint: this style of fused weight dequantization is having a bit of a moment in the field right now, and it isn't clear to me that it isn't something we want to be performant, even given its flaws and the potential for something better.

Max191 commented 1 year ago

To give a bit more context, the target number for performance is currently around 100ms for int4, while the number we are getting now with llama2 is around 1000ms.

On an older vicuna model (discussion here https://github.com/openxla/iree/issues/14337) we got the performance down to ~160ms for int8, but that model was in fp32. In the work for the previous vicuna model, we determined that the performance hit from not fusing the dequantization and matvec would prevent us from hitting that target.

bjacob commented 1 year ago

@stellaraccident, OK then. If we want to go down this path, let's first take a close look at the disassembly of the generated code (command-line flags as in the above comment), and at a --mlir-print-ir-after-all log to get a sense of how we arrive at that generated code. Then maybe we can find something glaring that we can fix to get the codegen into better shape; after all, the good thing when perf is 13x off is that there's probably something obvious to improve on.

I'm saying that, this being the ask, my first choice is to hopefully deal with it in codegen without a custom ukernel, because I feel that work is going to be more useful long-term and less throw-away. But doing so will also teach us things we don't currently know, which could make us revisit and do a ukernel.

@Max191 OK, I don't understand all of these figures, but I'd say that ultimately, if your model is mostly matmuls (or some special case of matmuls, like matvecs), then it comes down to how many billion scalar ops there are to execute, and how many we can execute per second (Gop/s) in each arithmetic path (i8, f16, f32, etc). On most recent x86 CPUs (with avx512vnni, like here), the i8 path gives ~2x more Gop/s than the f32 path, so it's a safe bet to say it will perform best e2e.

bjacob commented 1 year ago

@stellaraccident - Or, if we really want to achieve the best impact for the effort on this very specific use case in the short term, it's actually OK to just bite the bullet and do a ukernel for this fusion, but we have to be clear that it is likely throw-away work; maybe that's OK because, in exchange, we get to de-risk this whole thing: we know exactly what fusion we need to match as a ukernel and what the ukernel needs to do.

But if we are going to do that, one thing isn't adding up. If we are doing a one-off for a particular CPU that has support for avx512bf16 but not avx512fp16, and we are dequantizing to float just as a way of implementing i4 quantized arithmetic, then really we should instead be dequantizing to bfloat16 and accumulating into float32.
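
For reference, a minimal sketch of what the per-element combiner arithmetic would look like if we dequantized to bf16 and accumulated into f32 (the function name is illustrative, and whether the x86 backend then maps this onto avx512bf16 instructions is a separate question):

    func.func @bf16_mul_f32_acc(%a: bf16, %b: bf16, %acc: f32) -> f32 {
      // Operands held as bf16; the product and accumulation are carried out in f32.
      %0 = arith.extf %a : bf16 to f32
      %1 = arith.extf %b : bf16 to f32
      %2 = arith.mulf %0, %1 : f32
      %3 = arith.addf %2, %acc : f32
      return %3 : f32
    }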

bjacob commented 1 year ago

What isn't clear to me in this case is how the quantization would need to change or be specified to get us into a pure int world.

I would like to help with that, if I can be connected with the stakeholders, whenever that's timely.

stellaraccident commented 1 year ago

Yeah, in this case, the models are being borrowed from the GPU side vs considered fresh for a specific microarchitecture.

@bjacob I think we'd be very much open to recommendations on the path forward vs being prescriptive. As with most things ML, we may need all of it at the end of the day, but how you get there makes a big difference. Let me know if you'd like to sync and strategize.

Max191 commented 1 year ago

I uploaded the IR dumps and disassemblies here: https://drive.google.com/drive/folders/1g5M2GVHQjzZH_-xEHZnIvb5SbLhjSxAg?usp=sharing

bjacob commented 1 year ago

Excerpt from quantized_matmul_f16_objdump.txt (from @Max191's upload in the previous comment), doing the loads of i8 and the conversion to f16, from the first generic in the source:

    ^bb0(%in: i8, %in_0: f16, %in_1: f16, %out: f16):
      %9 = arith.extui %in : i8 to i32
      %10 = arith.uitofp %9 : i32 to f16
     [...]

Disassembly excerpt:

    18be:   62 31 7d 08 c5 dd 03    vpextrw r11d,xmm21,0x3
    18c5:   62 b1 7d 08 c5 dd 04    vpextrw ebx,xmm21,0x4
    18cc:   62 61 0e 08 2a c0       vcvtsi2ss xmm24,xmm14,eax
    18d2:   62 61 0e 08 2a c9       vcvtsi2ss xmm25,xmm14,ecx
    18d8:   c5 f9 c5 c6 06          vpextrw eax,xmm6,0x6
    18dd:   62 b1 7d 08 c5 cd 05    vpextrw ecx,xmm21,0x5
    18e4:   62 61 0e 08 2a d2       vcvtsi2ss xmm26,xmm14,edx
    18ea:   62 61 0e 08 2a de       vcvtsi2ss xmm27,xmm14,esi
    18f0:   62 61 0e 08 2a e7       vcvtsi2ss xmm28,xmm14,edi
    18f6:   62 41 0e 08 2a e8       vcvtsi2ss xmm29,xmm14,r8d
    18fc:   62 e1 7d 08 7e ea       vmovd  edx,xmm21
    1902:   62 41 0e 08 2a f1       vcvtsi2ss xmm30,xmm14,r9d
    1908:   0f b7 d2                movzx  edx,dx
    190b:   62 e1 0e 08 2a fa       vcvtsi2ss xmm23,xmm14,edx
    1911:   62 c1 0e 08 2a f2       vcvtsi2ss xmm22,xmm14,r10d
    1917:   62 c1 0e 08 2a e6       vcvtsi2ss xmm20,xmm14,r14d
    191d:   62 c1 0e 08 2a db       vcvtsi2ss xmm19,xmm14,r11d
    1923:   62 e1 0e 08 2a d3       vcvtsi2ss xmm18,xmm14,ebx
    1929:   62 e1 0e 08 2a c9       vcvtsi2ss xmm17,xmm14,ecx
    192f:   c4 e3 79 1d ed 04       vcvtps2ph xmm5,xmm5,0x4
    1935:   62 b1 7d 08 c5 cd 06    vpextrw ecx,xmm21,0x6
    193c:   c4 e3 79 1d db 04       vcvtps2ph xmm3,xmm3,0x4
    1942:   62 93 7d 08 1d c7 04    vcvtps2ph xmm31,xmm0,0x4
    1949:   62 31 7d 08 c5 c5 07    vpextrw r8d,xmm21,0x7
    1950:   c4 63 79 1d d8 04       vcvtps2ph xmm0,xmm11,0x4
    1956:   62 e1 0e 08 2a e8       vcvtsi2ss xmm21,xmm14,eax
    195c:   c4 e3 79 1d d1 04       vcvtps2ph xmm1,xmm2,0x4
    1962:   c4 43 79 1d d2 04       vcvtps2ph xmm10,xmm10,0x4

The problem that I can see here is that the vcvtsi2ss instructions are scalar, not vector, ops, converting one integer value to f32 at a time. (You might ask why go through f32 at all, when the source only asked for an i8->f16 conversion? That's fair, but x86 does not seem to have a direct instruction for that, the F16C extension only bringing conversions between f16 and f32.)

The IR print-after-all log (ir_dump_f16.mlir) from @Max191's upload shows that, at the LLVM dialect level, these llvm.uitofp ops are still vector ops:

    %312 = llvm.mlir.undef : !llvm.array<8 x array<1 x vector<64xf16>>>
    %313 = llvm.uitofp %304 : vector<64xi32> to vector<64xf16>
    %314 = llvm.insertvalue %313, %312[0, 0] : !llvm.array<8 x array<1 x vector<64xf16>>> 
    %315 = llvm.uitofp %305 : vector<64xi32> to vector<64xf16>
    %316 = llvm.insertvalue %315, %314[1, 0] : !llvm.array<8 x array<1 x vector<64xf16>>> 
    %317 = llvm.uitofp %306 : vector<64xi32> to vector<64xf16>
    %318 = llvm.insertvalue %317, %316[2, 0] : !llvm.array<8 x array<1 x vector<64xf16>>> 
    %319 = llvm.uitofp %307 : vector<64xi32> to vector<64xf16>
    %320 = llvm.insertvalue %319, %318[3, 0] : !llvm.array<8 x array<1 x vector<64xf16>>> 
    %321 = llvm.uitofp %308 : vector<64xi32> to vector<64xf16>
    %322 = llvm.insertvalue %321, %320[4, 0] : !llvm.array<8 x array<1 x vector<64xf16>>> 
    %323 = llvm.uitofp %309 : vector<64xi32> to vector<64xf16>
    %324 = llvm.insertvalue %323, %322[5, 0] : !llvm.array<8 x array<1 x vector<64xf16>>> 
    %325 = llvm.uitofp %310 : vector<64xi32> to vector<64xf16>
    %326 = llvm.insertvalue %325, %324[6, 0] : !llvm.array<8 x array<1 x vector<64xf16>>> 
    %327 = llvm.uitofp %311 : vector<64xi32> to vector<64xf16>

So whatever is going wrong, causing this to be scalarized, must be happening in LLVM after "us"... @stellaraccident mentioned that one of you might have the knowledge to help here: @dcaballe @qcolombet @jpienaar ?

dcaballe commented 1 year ago

Taking a quick look, I can see that this MLIR conversion sequence:

  %241 = llvm.zext %88 : vector<64xi8> to vector<64xi32>                                                                                                                            
  %313 = llvm.uitofp %241 : vector<64xi32> to vector<64xf16>

is folded by LLVM into:

  %176 = uitofp <64 x i8> %53 to <64 x half>, !dbg !84                                                                                                                                            

which makes sense semantically...

However, when the backend processes this <64 x i8> -> <64 x half> conversion without a native instruction for it, it tries to do <64 x i8> -> <64 x i16> followed by <64 x i16> -> <64 x half>, and given that there is no instruction for the latter, it must go into "panic mode" and scalarize the code with some default emulation... It's interesting, though, that it emulates the conversion by going through f32, which should give a different result than going from i16 to half directly... I wonder why that emulation is not kept in the vector world, though... Perhaps it's just not implemented.

I'd suggest that we open a bug against the x86 backend with a small test with a uitofp <64 x i8> %53 to <64 x half> instruction.

Hopefully that helps!

qcolombet commented 1 year ago

To add to what @dcaballe said, what happens is:

Long story short, the X86 backend misses a better lowering path for v32i8 -> v32fp16; like @dcaballe said, "it's just not implemented".

powderluv commented 1 year ago

Can @Max191 implement it with your guidance? Is it a heavy lift?

Max191 commented 1 year ago

I have been looking into why we have arith.mulf and vector.reduction instead of vector.fma.

The reason is that the vector.contract ops these come from don't get lowered to vector.outerproduct first, and this happens because there isn't upstream support for that lowering with multiple reduction dims. I see a few options for how to fix this, but wanted to share here for discussion on what the best solution is.
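
For illustration, here is a minimal standalone example of the problematic shape (vector sizes are made up, standing in for whatever tiling produces): a vector.contract with two reduction dims, one of them a unit dim, which the upstream contract-to-outerproduct lowering does not currently handle, so it ends up as arith.mulf + reductions instead:

    func.func @contract_with_unit_reduction_dim(%lhs: vector<4x1x16xf16>,
                                                %rhs: vector<1x16xf16>,
                                                %acc: vector<4xf16>) -> vector<4xf16> {
      // d1 is a unit reduction dim; folding it away (options 1/2 below) would leave a
      // single-reduction contract that the existing outerproduct lowering can handle.
      %0 = vector.contract {
          indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                           affine_map<(d0, d1, d2) -> (d1, d2)>,
                           affine_map<(d0, d1, d2) -> (d0)>],
          iterator_types = ["parallel", "reduction", "reduction"],
          kind = #vector.kind<add>}
        %lhs, %rhs, %acc : vector<4x1x16xf16>, vector<1x16xf16> into vector<4xf16>
      return %0 : vector<4xf16>
    }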

  1. Fold unit reduction dims at the linalg level (after tiling, the outermost reduction dim should be 1, based on my tuning a while ago). I tried this, but it prevents broadcasted reads from happening on the scales and zero points. If we want this to work, we would have to somehow recover these broadcasted reads.
  2. Fold the unit dims at the vector level on the vector.contract. We would need to make a new pattern for this, but it does not seem very difficult.
  3. Add support for lowering multi-reduction vector.contract ops into vector.outerproduct.
  4. Raise the arith.mulf and vector.reduction into vector.contract after they are flattened, and then lower the new vector.contract ops normally. There is a pattern that does this raising for vector.multi_reduction ops, and it is not hard to add a new pattern to raise regular vector.reduction ops. This seems like a messy way to do it, but it is something I was trying as a quick way to get a proof of concept (until I ran into a problem with my pattern that I am debugging).

powderluv commented 1 year ago

also from discord https://discord.com/channels/689900678990135345/1146173056537079919/1146451713616773231

I'd suggest also following @Diego Caballero's suggestion: "I'd suggest that we open a bug against the x86 backend with a small test with a uitofp <64 x i8> %53 to <64 x half> instruction."
You never know; if you're lucky, that hits the inbox of the right x86 maintainer...

qcolombet commented 1 year ago

Can @Max191 implement it with your guidance? Is it a heavy lift?

Sure. It shouldn't be too hard, assuming you know which instruction sequence you want to go for. I haven't looked at the x86 ISA for a while, so figuring out what we want here may take some time.

Max191 commented 1 year ago

I'd suggest that we open a bug against the x86 backend with a small test with a uitofp <64 x i8> %53 to <64 x half> instruction.

@dcaballe Is this example sufficient for the bug?

define <8 x half> @main() {
    %int8_p = alloca <8 x i8>
    %int8 = load <8 x i8>, <8 x i8>* %int8_p
    %fp16 = uitofp <8 x i8> %int8 to <8 x half>
    ret <8 x half> %fp16
}

qcolombet commented 1 year ago

I'd suggest that we open a bug against the x86 backend with a small test with a uitofp <64 x i8> %53 to <64 x half> instruction.

@dcaballe Is this example sufficient for the bug?

define <8 x half> @main() {
    %int8_p = alloca <8 x i8>
    %int8 = load <8 x i8>, <8 x i8>* %int8_p
    %fp16 = uitofp <8 x i8> %int8 to <8 x half>
    ret <8 x half> %fp16
}

Yes, that'll work. Just two things:

Max191 commented 1 year ago

I realized I forgot to add here that our target is to be faster than https://github.com/ggerganov/llama.cpp, which on my machine is about 70ms for llama2 7b int4.

I looked into what they do in llama.cpp, and they always dequantize into f32 for all quantization sizes, so we should probably start by doing this before we attempt to go into the x86 backend ourselves. I'll post the bug and then work on a pattern to convert our i4->f16 dequantizations into i4->f32 for now. Hopefully someone upstream will see our bug and take it on while we do this.
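
A minimal sketch of one possible target form of that rewrite (using i8 as in the testcase above; the function name is illustrative, and the real pattern would rewrite the linalg.generic body rather than a standalone function): keep the stored data in i8/f16, but do the dequantization arithmetic in f32 and only truncate at the end, avoiding the scalarized i8->f16 path in the x86 backend:

    func.func @dequant_compute_in_f32(%q: i8, %scale: f16, %zp: f16) -> f16 {
      %0 = arith.extui %q : i8 to i32
      %1 = arith.uitofp %0 : i32 to f32
      %scale_f32 = arith.extf %scale : f16 to f32
      %zp_f32 = arith.extf %zp : f16 to f32
      // Same subtract-zero-point-then-scale arithmetic as the original body, done in f32.
      %2 = arith.subf %1, %zp_f32 : f32
      %3 = arith.mulf %2, %scale_f32 : f32
      %4 = arith.truncf %3 : f32 to f16
      return %4 : f16
    }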

Max191 commented 1 year ago

Thanks @qcolombet, I will add the float example and post the bug. I can switch it to 64xi8 as well. I only used 8 to make it a little smaller, but we should probably keep it the same as what we see in the model. Also, eventually we want it to work for i4 too, so I think I should probably switch it to i4 actually.

qcolombet commented 1 year ago

Thanks @qcolombet, I will add the float example and post the bug. I can switch it to 64xi8 as well. I only used 8 to make it a little smaller, but we should probably keep it the same as what we see in the model. Also, eventually we want it to work for i4 too, so I think I should probably switch it to i4 actually.

For the bug, stick to i8 (as opposed to i4) to make it less scary. Then for the number of elements, 8 is also a good start (as opposed to 64).

dcaballe commented 1 year ago

https://github.com/llvm/llvm-project/issues/67080

qcolombet commented 9 months ago

@Max191 FYI, we improved the lowering of i8|i16 -> fp16 with https://github.com/llvm/llvm-project/pull/70834. You may want to redo the perf measurements here to see where we stand now.