iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Comparison with GGML for quantized matmuls #14951

Open Max191 opened 12 months ago

Max191 commented 12 months ago

This issue is to provide some context and discussion for current quantized matmuls compared against what is done for GGML in llama.cpp.

So far, I have looked at the i4 quantization case. In llama2, we have been looking at matvecs, where the corresponding inputs to GGML's matmul kernel would be an LHS quantized matrix times an RHS unquantized vector.

The matmul is defined in ggml_compute_forward_mul_mat, and the functions for handling vector dot products for various quantization types are defined here

The critical code is in how it handles these vector dot products:

ggml_vec_dot_t    const vec_dot               = type_traits[type].vec_dot;
enum ggml_type    const vec_dot_type          = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
[GGML_TYPE_Q4_1] = {
    .type_name                = "q4_1",
    .blck_size                = QK4_1,
    .type_size                = sizeof(block_q4_1),
    .is_quantized             = true,
    .to_float                 = (ggml_to_float_t) dequantize_row_q4_1,
    .from_float               = quantize_row_q4_1,
    .from_float_reference     = (ggml_from_float_t) quantize_row_q4_1_reference,
    .vec_dot                  = ggml_vec_dot_q4_1_q8_1,
    .vec_dot_type             = GGML_TYPE_Q8_1,
},

...

[GGML_TYPE_Q8_1] = {
    .type_name                = "q8_1",
    .blck_size                = QK8_1,
    .type_size                = sizeof(block_q8_1),
    .is_quantized             = true,
    .from_float               = quantize_row_q8_1,
    .from_float_reference     = (ggml_from_float_t) quantize_row_q8_1_reference,
    .vec_dot_type             = GGML_TYPE_Q8_1,
},

GGML quantizes the RHS here into the vec_dot_type (GGML_TYPE_Q8_1) before performing the matmul, and then uses ggml_vec_dot_q4_1_q8_1 for the matmul.

The dot product looks like this:

static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
    const int qk = QK8_1;
    const int nb = n / qk;

    assert(n % qk == 0);

    const block_q4_1 * restrict x = vx;
    const block_q8_1 * restrict y = vy;

    ...

    // scalar
    float sumf = 0.0;

    for (int i = 0; i < nb; i++) {
        int sumi = 0;

        for (int j = 0; j < qk/2; ++j) {
            const int v0 = (x[i].qs[j] & 0x0F);
            const int v1 = (x[i].qs[j] >>   4);

            sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
        }

        sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
    }

    *s = sumf;
}

The sum of products sumi is computed in integer arithmetic on i8 operands (accumulating into an int), with the i4 LHS being unpacked here into byte-sized containers, and then sumi is multiplied by the scales for the LHS x[i].d and RHS y[i].d. The zero points are handled after the sum of products by adding GGML_FP16_TO_FP32(x[i].m)*y[i].s.

x[i].m is the min value of the LHS quantized group, while y[i].s is equal to sum_j(y[i].qs[j]) * y[i].d, i.e. the scale times the sum of the quantized values (this is computed in quantize_row_q8_1).

So the rhs is symmetrically quantized dynamically, and then the dot product is reassociated from:

((lhs_q*lhs_scale)+lhs_zp) • (rhs_q * rhs_scale)

to

(lhs_q • rhs_q) * lhs_scale * rhs_scale + (lhs_zp * rhs_scale * sum(rhs_q))

It is also worth noting that GGML does ((lhs_q*lhs_scale)+lhs_zp) and we do ((lhs_q-lhs_zp)*lhs_scale), so I think they save 1 multiplication in the reassociated form with this quantization difference.

The quantization scheme is essentially the same, however, and the main thing that GGML does to do these multiplications in i8 is the dynamic requantization of the RHS vector back into i8 before doing the matmul/matvec.
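
The reassociation above can be checked numerically with a small C sketch. This is only an illustration for a single quantization group, not GGML or IREE code: the function names are made up here, the LHS is taken as unpacked i4 values in bytes, and lhs_min plays the role of x[i].m in the GGML form (lhs_q * lhs_scale + lhs_min).

```c
#include <assert.h>
#include <stdint.h>

/* Direct form: dequantize both sides to float, then take the dot product. */
static float dot_direct(const uint8_t *lhs_q, float lhs_scale, float lhs_min,
                        const int8_t *rhs_q, float rhs_scale, int n) {
  float sum = 0.0f;
  for (int j = 0; j < n; j++) {
    float lhs = (float)lhs_q[j] * lhs_scale + lhs_min;
    float rhs = (float)rhs_q[j] * rhs_scale;
    sum += lhs * rhs;
  }
  return sum;
}

/* Reassociated form: integer sum of products, with the scales and the
   min term applied once per group instead of once per element. */
static float dot_reassociated(const uint8_t *lhs_q, float lhs_scale,
                              float lhs_min, const int8_t *rhs_q,
                              float rhs_scale, int n) {
  int32_t sumi = 0;    /* lhs_q . rhs_q */
  int32_t rhs_sum = 0; /* sum(rhs_q) */
  for (int j = 0; j < n; j++) {
    assert(lhs_q[j] < 16); /* unpacked i4 values */
    sumi += (int32_t)lhs_q[j] * (int32_t)rhs_q[j];
    rhs_sum += rhs_q[j];
  }
  /* Corresponds to y[i].s in GGML: scale times the sum of quantized values. */
  float rhs_scaled_sum = rhs_scale * (float)rhs_sum;
  return lhs_scale * rhs_scale * (float)sumi + lhs_min * rhs_scaled_sum;
}
```

For lhs_q = {1, 15, 0, 7}, rhs_q = {-3, 2, 127, -128}, lhs_scale = 0.5, lhs_min = -1.0 and rhs_scale = 0.25, both functions give -108.125; the reassociated version only touches floats after the integer loop.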

Max191 commented 12 months ago

I created some linalg IR to represent this reassociation so we can have something to look at and start iterating on: https://gist.github.com/Max191/c76474d1d484cb05d191543877110d28

One thing to note that is slightly different from what it seems GGML is doing in computing the scaled sums of the vector:

%vec_scaled_sums = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                                                    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], 
                                   iterator_types = ["parallel", "parallel", "parallel", "reduction"]} 
                                   ins(%vector : tensor<1x1x32x128xf32>) 
                                   outs(%2 : tensor<1x1x32xf32>) {
^bb0(%vec: f32, %out: f32):
  %sum = arith.addf %vec, %out : f32
  linalg.yield %sum : f32
} -> tensor<1x1x32xf32>

Here I simply use the sum of the original unquantized values. However GGML first quantizes, then computes the sum of the quantized values times the scales:

static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
    assert(QK8_1 == 32);
    assert(k % QK8_1 == 0);
    const int nb = k / QK8_1;

    for (int i = 0; i < nb; i++) {
        float amax = 0.0f; // absolute max

        for (int j = 0; j < QK8_1; j++) {
            const float v = x[i*QK8_1 + j];
            amax = MAX(amax, fabsf(v));
        }

        const float d = amax / ((1 << 7) - 1);
        const float id = d ? 1.0f/d : 0.0f;

        y[i].d = d;

        int sum = 0;

        for (int j = 0; j < QK8_1/2; ++j) {
            // GGML quantizing values here
            const float v0 = x[i*QK8_1           + j]*id;
            const float v1 = x[i*QK8_1 + QK8_1/2 + j]*id;

            y[i].qs[          j] = roundf(v0);
            y[i].qs[QK8_1/2 + j] = roundf(v1);

            // Then computing the sum of the quantized values
            sum += y[i].qs[          j];
            sum += y[i].qs[QK8_1/2 + j];
        }
        // Multiply by scales
        y[i].s = sum*d;
    }
}

The GGML implementation loses some information here, and from my testing, the results of the matmul are more precise when computing the scaled sum before quantization rather than after quantization.
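
To make the difference concrete, here is a rough C sketch of the two ways of computing the scaled sum for one group of 32 values. It is an illustration only: the helper names are hypothetical, and round_half_away is a stand-in for roundf so the snippet does not depend on the math library.

```c
#include <assert.h>
#include <stdint.h>

#define GROUP 32 /* same group size as QK8_1 */

static float abs_f(float v) { return v < 0.0f ? -v : v; }

/* Round half away from zero, standing in for roundf. */
static int round_half_away(float v) {
  return (int)(v >= 0.0f ? v + 0.5f : v - 0.5f);
}

/* Scaled sum taken from the original values, as in the linalg.generic above. */
static float scaled_sum_before_quant(const float *x) {
  float sum = 0.0f;
  for (int j = 0; j < GROUP; j++) sum += x[j];
  return sum;
}

/* Quantizes one group to i8 and returns sum(q) * d, like y[i].s in
   quantize_row_q8_1_reference. The scale d is written to *d_out. */
static float scaled_sum_after_quant(const float *x, int8_t *q, float *d_out) {
  float amax = 0.0f; /* absolute max */
  for (int j = 0; j < GROUP; j++) {
    float a = abs_f(x[j]);
    if (a > amax) amax = a;
  }
  const float d = amax / 127.0f;
  const float id = d != 0.0f ? 1.0f / d : 0.0f;
  int sum = 0;
  for (int j = 0; j < GROUP; j++) {
    q[j] = (int8_t)round_half_away(x[j] * id);
    sum += q[j];
  }
  *d_out = d;
  return (float)sum * d;
}
```

Each quantized value is off by at most half a quantization step, so the two sums can differ by up to about 0.5 * d * 32 per group. That rounding error is the information lost when the sum is taken after quantization.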

Max191 commented 11 months ago

I made a draft for a pattern that handles the transformation into the above form here https://github.com/openxla/iree/pull/14975. The pattern needs some reworking, but we can start looking at the codegen now.

I'll also include here the IR dump from compiling the following IR:

builtin.module {
  func.func @quantized_matmul(%arg0: tensor<11008x32x128xi4>, %arg1: tensor<11008x32x1xf32>, %arg2: tensor<11008x32x1xf32>, %arg3: tensor<1x1x32x128xf32>) -> tensor<1x1x11008xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %4 = tensor.empty() : tensor<1x1x11008xf32>
    %5 = tensor.empty() : tensor<11008x32x128xf32>
    %6 = linalg.fill ins(%cst : f32) outs(%4 : tensor<1x1x11008xf32>) -> tensor<1x1x11008xf32>

    %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg1, %arg2 : tensor<11008x32x128xi4>, tensor<11008x32x1xf32>, tensor<11008x32x1xf32>) outs(%5 : tensor<11008x32x128xf32>) {
    ^bb0(%in: i4, %in_0: f32, %in_1: f32, %out: f32):
      %9 = arith.extui %in : i4 to i32
      %10 = arith.uitofp %9 : i32 to f32
      %11 = arith.subf %10, %in_1 : f32
      %12 = arith.mulf %11, %in_0 : f32
      linalg.yield %12 : f32
    } -> tensor<11008x32x128xf32>
    %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg3, %7 : tensor<1x1x32x128xf32>, tensor<11008x32x128xf32>) outs(%6 : tensor<1x1x11008xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %9 = arith.mulf %in, %in_0 : f32
      %10 = arith.addf %9, %out : f32
      linalg.yield %10 : f32
    } -> tensor<1x1x11008xf32>
    return %8 : tensor<1x1x11008xf32>
  }
}

IR dump: https://drive.google.com/file/d/1VrqHykySmdWEt-YSek01tt3EqxSpBDhW/view?usp=sharing

bjacob commented 11 months ago

Extracting from the 72 MB dump.mlir the before/after of the FuseDequantizationMatmul pass from #14975:

https://gist.github.com/bjacob/0242746ba4b90c6c8a0314326b789f80

Max191 commented 11 months ago

I just realized I uploaded the dump with FoldUnitExtentDims running after fusion. Here is another dump where the unit dims are folded away. Sorry for the mistake before, this one should be easier to read

https://drive.google.com/file/d/17FebNFZkxlb_WJCAKD98WqYWLZVgO1QA/view?usp=sharing

And the results from the fusion pass: https://gist.github.com/Max191/ae5f79c327dda99bafd273799f0a5cbe

Max191 commented 11 months ago

After some changes to the reassociation pass, the text generation is looking good with the reassociated quantized matmul. Now we can start to work on getting an efficient lowering.

Here is an IR dump for what we have now: https://drive.google.com/file/d/1WtJ5giqVItLLmH2jMhlrtYIzgNeXqeSo/view?usp=sharing

And a gist with some more isolated information: https://gist.github.com/Max191/6ff02702d6b3ad2e760fafa6e53f02ea

It seems to me that the most glaring issue is the lowering of the vector.contract that comes out of the integer matmul (producer in the reassociated form):

// -----// IR Dump After RemoveSingleIterationLoop (iree-codegen-remove-single-iteration-loop) //----- //
func.func @quantized_matmul_dispatch_3_generic_11008x32x128_i16xi4xi32() {
  %cst = arith.constant dense<0.000000e+00> : vector<8xf32>
  %c0_i4 = arith.constant 0 : i4
  %c0_i16 = arith.constant 0 : i16
  %cst_0 = arith.constant dense<0> : vector<8x32xi32>
  %c16 = arith.constant 16 : index
  %c32 = arith.constant 32 : index
  %c8 = arith.constant 8 : index
  %c0 = arith.constant 0 : index
  %c256 = arith.constant 256 : index
  %c128 = arith.constant 128 : index
  %cst_1 = arith.constant 0.000000e+00 : f32
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c256) flags(ReadOnly) : memref<32x128xi16, strided<[128, 1], offset: 128>, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %0, 64 : memref<32x128xi16, strided<[128, 1], offset: 128>, #hal.descriptor_type<storage_buffer>>
  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<11008x32x128xi4, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %1, 64 : memref<11008x32x128xi4, #hal.descriptor_type<storage_buffer>>
  %2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<32xf32, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %2, 64 : memref<32xf32, #hal.descriptor_type<storage_buffer>>
  %3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c128) flags(ReadOnly) : memref<32xf32, strided<[1], offset: 32>, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %3, 64 : memref<32xf32, strided<[1], offset: 32>, #hal.descriptor_type<storage_buffer>>
  %4 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<11008x32xf32, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %4, 64 : memref<11008x32xf32, #hal.descriptor_type<storage_buffer>>
  %5 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<11008x32xf32, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %5, 64 : memref<11008x32xf32, #hal.descriptor_type<storage_buffer>>
  %6 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) alignment(64) offset(%c0) : memref<11008xf32, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %6, 64 : memref<11008xf32, #hal.descriptor_type<storage_buffer>>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %7 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
  %subview = memref.subview %6[%7] [32] [1] : memref<11008xf32, #hal.descriptor_type<storage_buffer>> to memref<32xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  %subview_2 = memref.subview %1[%7, 0, 0] [32, 32, 128] [1, 1, 1] : memref<11008x32x128xi4, #hal.descriptor_type<storage_buffer>> to memref<32x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  %subview_3 = memref.subview %4[%7, 0] [32, 32] [1, 1] : memref<11008x32xf32, #hal.descriptor_type<storage_buffer>> to memref<32x32xf32, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  %subview_4 = memref.subview %5[%7, 0] [32, 32] [1, 1] : memref<11008x32xf32, #hal.descriptor_type<storage_buffer>> to memref<32x32xf32, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  scf.for %arg0 = %c0 to %c32 step %c8 {
    %8 = scf.for %arg1 = %c0 to %c128 step %c16 iter_args(%arg2 = %cst_0) -> (vector<8x32xi32>) {
      %22 = vector.transfer_read %0[%c0, %arg1], %c0_i16 {in_bounds = [true, true]} : memref<32x128xi16, strided<[128, 1], offset: 128>, #hal.descriptor_type<storage_buffer>>, vector<32x16xi16>
      %23 = vector.transfer_read %subview_2[%arg0, %c0, %arg1], %c0_i4 {in_bounds = [true, true, true]} : memref<32x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8x32x16xi4>
      %24 = arith.extsi %22 : vector<32x16xi16> to vector<32x16xi32>
      %25 = arith.extui %23 : vector<8x32x16xi4> to vector<8x32x16xi32>
      %26 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %24, %25, %arg2 : vector<32x16xi32>, vector<8x32x16xi32> into vector<8x32xi32>
      scf.yield %26 : vector<8x32xi32>
    }
    %9 = vector.transfer_read %2[%c0], %cst_1 {in_bounds = [true]} : memref<32xf32, #hal.descriptor_type<storage_buffer>>, vector<32xf32>
    %10 = vector.broadcast %9 : vector<32xf32> to vector<8x32xf32>
    %11 = vector.transfer_read %3[%c0], %cst_1 {in_bounds = [true]} : memref<32xf32, strided<[1], offset: 32>, #hal.descriptor_type<storage_buffer>>, vector<32xf32>
    %12 = vector.broadcast %11 : vector<32xf32> to vector<8x32xf32>
    %13 = vector.transfer_read %subview_3[%arg0, %c0], %cst_1 {in_bounds = [true, true]} : memref<32x32xf32, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8x32xf32>
    %14 = vector.transfer_read %subview_4[%arg0, %c0], %cst_1 {in_bounds = [true, true]} : memref<32x32xf32, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8x32xf32>
    %15 = arith.sitofp %8 : vector<8x32xi32> to vector<8x32xf32>
    %16 = arith.mulf %15, %10 : vector<8x32xf32>
    %17 = arith.mulf %16, %13 : vector<8x32xf32>
    %18 = arith.mulf %14, %13 : vector<8x32xf32>
    %19 = arith.mulf %18, %12 : vector<8x32xf32>
    %20 = arith.subf %17, %19 : vector<8x32xf32>
    %21 = vector.multi_reduction <add>, %20, %cst [1] : vector<8x32xf32> to vector<8xf32>
    vector.transfer_write %21, %subview[%arg0] {in_bounds = [true]} : vector<8xf32>, memref<32xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  }
  return
}

I believe if we can tile the d1 (number of groups) dimension to 1 here, then we should be able to fold it away in the inner tile, and the resulting vector contract would look like:

%26 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %24, %25, %arg2 : vector<16xi32>, vector<8x16xi32> into vector<8xi32>

This looks like a normal matvec, which should make it much easier to get a good lowering out of this contract.

Edit: Here is the branch with the pass: https://github.com/Max191/iree/tree/reassociate_quantized_matmul

Max191 commented 11 months ago

After doing the tiling described above, the vector.contract has a unit dimension for the number of groups:

%27 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, 
                                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>, 
                                        affine_map<(d0, d1, d2) -> (d0, d1)>], 
                        iterator_types = ["parallel", "parallel", "reduction"], 
                        kind = #vector.kind<add>} 
                        %25, %26, %arg3 : vector<1x16xi32>, vector<8x1x16xi32> into vector<8x1xi32>

Now, we can try to use VectorContractCustomKernels on this contract. I assume we can either fold away this unit dimension to make matching easier, or just add a matching case for this form. @bjacob Which do you think is better? I'm not sure if there is already a pattern that does this folding (there is one if the unit dimension were the outermost dimension, but maybe not for inner dimensions), but we can add one if need be

Here is the dump with the new tiling for reference as well: https://drive.google.com/file/d/1awAGq2gMDGnk6xEsWwg8KY3cDfZ0ijVV/view?usp=sharing

bjacob commented 11 months ago

Yes, it would be ideal to let a pattern do this folding. I don't know either if one exists (@MaheshRavishankar might know). From a quick scan of https://github.com/llvm/llvm-project/blob/8f8f4493d5fc6a025fff678f030da73dcfd8baa7/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp it looks like it might only be for dropping the leading dimension, like you noted.

Once the unit dim is dropped, this vector.contract will be more or less in the right shape to be matched by the existing code in VectorContractCustomKernels.

Like we discussed yesterday, we need to decide which specific shape to match in VectorContractCustomKernels, so you will have to tile/unroll to that shape.

Let's adopt the shape convention in use in VectorContractCustomKernels here: shapes are MxKxN where M is the number of rows of LHS (so it's 1 in a vector*matrix product), K is the reduction dimension size (the dimension shared by the LHS and the RHS and not appearing in the result) and N is the number of columns of RHS.

So for instance, if the vector.contract is as above

%26 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %24, %25, %arg2 : vector<16xi32>, vector<8x16xi32> into vector<8xi32>

Then here, M=1, K=16, N=8.

More generally, in any matrix x vector case, M=1, K=(length of input vector), N=(length of output vector).

OK, so now to choose the specific shape to match in VectorContractCustomKernels, this depends on the ISA. Let's focus on x86_64 for now. Since the element types are i16, i16, i32, we want to use these instructions, dictating these MxKxN shapes, depending on which CPU feature is available:

| CPU feature | Instruction to use | Instruction MxKxN shape |
| --- | --- | --- |
| MMX | PMADDWD + add accumulator | 1x2x2 |
| SSE2 | PMADDWD + add accumulator | 1x2x4 |
| AVX2 | VPMADDWD + add accumulator | 1x2x8 |
| AVX-512 | VPMADDWD + add accumulator | 1x2x16 |
| AVX-512-VNNI | VPDPWSSD | 1x2x16 |

Just start with the VNNI case since it's supported on recent CPUs (including AMD Zen4), and it's simplest (single instruction, no need for separate accumulator addition) and of course it'll perform best (thanks to being a single instruction).

The above is just the shape of a single instruction. Normally, we match VectorContractCustomKernels for widened shapes where the narrow dimension (M=1) is widened typically to match N, putting several instructions side by side, so that we don't awkwardly have just a couple of scalars sitting alone in the LHS register. But here, there is a complication specific to these instructions: they don't have broadcast semantics, so if our LHS register contains multiple rows, we have to add separate instructions to broadcast specific lanes. The problem is that from AVX2 onwards, that would mean broadcasting in >= 256bit registers across 128bit boundaries, which is relatively slow. To avoid too much headache for now, let's just stick to the above narrow shapes. The resulting code will not be perfectly efficient, but basically to fix that one needs to change the layout of the accumulator tile to avoid having to broadcast across 128bit boundaries, and that can only be done efficiently if we own the codegen for not just the vector.contract, but the whole enclosing loop. In other words, this is a reason to switch to ukernels at that point. See how the ukernels do it.

Once you've achieved the tiling to the above shapes, I can provide more specific implementation guidance in VectorContractCustomKernels. Here is an example of a similar enough case done on ARM with both inline asm and intrinsics variants:

inline asm: https://github.com/openxla/iree/blob/8910e1153a5f510f5a4bbf627cd8b21b3b6ee3e0/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp#L401-L441

intrinsics: https://github.com/openxla/iree/blob/8910e1153a5f510f5a4bbf627cd8b21b3b6ee3e0/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp#L1003-L1117

The caveat with intrinsics is that this relies on particular intrinsics being reflected in some AVX dialect, which adds some abstraction burden. So think of inline asm as a shortcut here, not necessarily something to be done out of an optimization concern.

bjacob commented 11 months ago

Follow-up to a conversation with @Max191: we need to change 2 things here to maximize performance:

  1. Let the input RHS element type be i4, moving the i4->i16 extension into the asm kernel implementation.
  2. Tile the vector.contract differently to better match what VPDPWSSD is doing, to avoid the need for broadcasts.

Of the two, 1. is lower effort than 2. and will already provide a nice boost, so it's OK to try 1 only at first.

Elaborating on 2 if you want to give it a try:

So far we have been tiling a big vector-times-matrix vector.contract into smaller vector-times-matrix ones. But VPDPWSSD isn't readily a vector-times-matrix; instead it is 16 separate dot-products. To implement a vector-times-matrix, the vector needs to be broadcast first, then that can be fed to VPDPWSSD.

We can tile this differently to get the opportunity to use VPDPWSSD's 16 independent dot-products, by taking advantage of the mutual independence along the K (reduction) dimension.

Say that our workload is a (M=1, K=128, N=8) vector-times-matrix. That is to say: we have an input lhs vector of length K=128 and we are multiplying it by a 128x8 rhs matrix, accumulating into a vector of length 8.

As C code:

// Code snippet #1

for (int i = 0; i < 8; i++) {
  for (int k = 0; k < 128; k++) {
    output[i] += lhs[k] * rhs[i][k];
  }
}

Let's introduce 16 local accumulators for each i, and tile the loop on k by 16, accumulating into these 16 local accumulators instead of directly accumulating into output[i], then add a second reduction-only loop adding these 16 local accumulators into output[i]:

// Code snippet #2

for (int i = 0; i < 8; i++) {
  int32_t acc[16] = {0};
  for (int k = 0; k < 128; k += 16) {
    for (int p = 0; p < 16; p++) {
      acc[p] += lhs[k + p] * rhs[i][k + p];
    }
  }
  for (int p = 0; p < 16; p++) {
    output[i] += acc[p];
  }
}

Then the point is that this hot loop is now exactly just one VPDPWSSD instruction. As a dummy "intrinsic":

// Code snippet #3

void VPDPWSSD(int32_t* acc, const int16_t* lhs, const int16_t* rhs) {
  for (int p = 0; p < 16; p++) {
    acc[p] += lhs[p] * rhs[p];
  }
}

Now our vector.contract is just

// Code snippet #4

for (int i = 0; i < 8; i++) {
  int32_t acc[16] = {0};
  for (int k = 0; k < 128; k += 16) {
    VPDPWSSD(acc, &lhs[k], &rhs[i][k]);
  }
  for (int p = 0; p < 16; p++) {
    output[i] += acc[p];
  }
}

So in terms of vector.contract, we start from the original vector-times-matrix (represented in Code snippet #1), which has 1 reduction iterator on k and 1 parallel iterator on i. We rewrite it into a new vector.contract (represented in Code snippet #2) with an additional parallel iterator on p, achieved by splitting the existing reduction iterator on k into 2 iterators but leaving one of them parallel, and deferring the reduction into a 2nd reduction-only vector.contract (that 2nd for loop on p at the end of Code snippet #2). Then VectorContractCustomKernels can match that VPDPWSSD vector.contract (Code snippet #3), performing the rewrite from snippet #2 to #4.

Does that make sense?

MaheshRavishankar commented 11 months ago

I think the LLVMCPUSplitReductionPass (https://github.com/openxla/iree/blob/main/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSplitReduction.cpp) does exactly this transformation at Linalg level. I think setting the right loweringConfig gets us here... @Max191 if you need we can code-golf on this together to get this form, but look at the examples of this pass to see how this works...

bjacob commented 11 months ago

Thanks @MaheshRavishankar for the LLVMCPUSplitReductionPass tip.

Thought a bit more about this issue:

1. Refinement of the above VectorContractCustomKernel approach.

Above I was suggesting, after applying LLVMCPUSplitReductionPass, tiling this down to the shape of one VPDPWSSD instruction. That is, a reduction dimension of size 2, and a parallel dimension of size 16, but the other parallel dimension of size 8 has been completely tiled down to size 1 so it's gone here. That was the correct thing to do as far as one VPDPWSSD instruction is concerned, but in order to get the best kernel, we need to preserve the original parallel dimension of size 8 all the way down to this VectorContractCustomKernel -- so we need to write a VectorContractCustomKernel for a group of 8 VPDPWSSD instructions sharing a common LHS vector, but each having their own separate RHS and Accumulator, as described by this "intrinsic":

void EIGHT_VPDPWSSD(int32_t acc[8][16], const int16_t lhs[16], const int16_t rhs[8][16]) {
  for (int i = 0; i < 8; i++) {
    for (int p = 0; p < 16; p++) {
      acc[i][p] += lhs[p] * rhs[i][p];
    }
  }
}

This is going to perform better for 2 reasons: (1) reuse of the LHS value will mean a higher ratio of useful arithmetic to memory-access instructions; (2) mutually independent arithmetic instructions will better utilize the CPU's pipelines.

2. Taking a step back --- maybe this can be a normal matmul for data-tiling and ukernels after all?

Maybe I didn't realize initially that the group size was 128, when I gave up trying to make this workload go through the existing fast paths that we have for matmuls. Thanks to your investigations here, I now better understand what we are doing, and in particular that, IIUC, within each group of size 128 we are really doing just a plain old matmul, so I wonder if we could just follow that route. That is, tile the K (reduction) dimension down to 128, perform any necessary linalg-to-linalg rewrites to bring each tile (now dealing with a single shared quantization scale) into a plain old int16 x int4 matmul, and let that go through data tiling and ukernels. We'd still add a dedicated new ukernel for that new combination of element types, int16 x int4, and we would still apply there what we have just learned about properly using x86 VPDPWSSD for vector-times-matrix cases (we could then backport that new understanding to improving on existing vector-times-matrix cases in data-tiling and ukernels).

Max191 commented 11 months ago
  1. Let the input RHS element type be i4, moving the i4->i16 extension into the asm kernel implementation.

  2. Tile the vector.contract differently to better match what VPDPWSSD is doing, to avoid the need for broadcasts.

Now that I am back, I have started so far by doing 2, but without any changes to codegen. The whole dot product can be done with 4 VPDPWSSD instructions, so I am doing everything in a 1x128x1 kernel by passing 4 RHS and 4 LHS registers, and horizontally adding after the 4 VPDPWSSD instructions. I did not do 1 yet because with K=128, we can actually generate pretty good assembly for the i4->i16 extension.
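
As a rough scalar model of this 1x128x1 kernel (a sketch under assumptions, not the actual inline asm): one 512-bit VPDPWSSD accumulates, in each of its 16 i32 lanes, the sum of two adjacent i16 x i16 products, so it consumes 32 elements of K and four of them cover K=128. The function names are made up for illustration, and a single shared accumulator register is assumed.

```c
#include <assert.h>
#include <stdint.h>

/* Scalar model of one 512-bit VPDPWSSD: lane p accumulates the two
   products of i16 elements 2p and 2p+1. The real instruction wraps on
   overflow; the inputs here are assumed small enough not to overflow. */
static void vpdpwssd_model(int32_t acc[16], const int16_t lhs[32],
                           const int16_t rhs[32]) {
  for (int p = 0; p < 16; p++) {
    acc[p] += (int32_t)lhs[2 * p] * rhs[2 * p] +
              (int32_t)lhs[2 * p + 1] * rhs[2 * p + 1];
  }
}

/* Hypothetical 1x128x1 kernel: four modeled instructions over K=128,
   followed by a horizontal add of the 16 lanes. */
static int32_t dot_k128(const int16_t lhs[128], const int16_t rhs[128]) {
  int32_t acc[16] = {0};
  for (int k = 0; k < 128; k += 32) {
    vpdpwssd_model(acc, &lhs[k], &rhs[k]);
  }
  int32_t total = 0;
  for (int p = 0; p < 16; p++) {
    total += acc[p]; /* horizontal add */
  }
  return total;
}
```

The horizontal add at the end must sum all 16 lanes exactly once per output element; comparing against a plain scalar loop over the 128 products is a quick way to check that step.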

With this change, we are pretty close to the target latency already (70ms), with a full model latency of 83ms (down from 105ms without integer math). However, I am having a correctness problem with the new kernel (I think I am doing the horizontal add wrong) that I need to fix before I confirm this new number.

Eventually I think we will want to use LLVMCPUSplitReductionPass, and make the kernel's K size smaller, because we are limited by the number of input registers if we start increasing the N size in the kernel as you have described above. In fact, the original parallel N size is very large (the smallest of these vecmats has N=4096 I believe), so we can increase N as much as is possible in the kernel.

We may also still want to move the i4-i16 extension into the kernel, but I'm not sure how much performance boost we will get if the K size of the kernel is sufficiently large.

I will update again when I figure out this correctness issue I am having.

bjacob commented 11 months ago

Thanks for the update.

Now that I am back, I have started so far by doing 2, but without any changes to codegen. The whole dot product can be done with 4 VPDPWSSD instructions, so I am doing everything in a 1x128x1 kernel by passing 4 RHS and 4 LHS registers, and horizontally adding after the 4 VPDPWSSD instructions. I did not do 1 yet because with K=128, we can actually generate pretty good assembly for the i4->i16 extension.

Some unrolling like that is what I had in mind in Part 1 of my previous comment, except I was thinking you will get higher performance by unrolling along the parallel instead of the reduction dimension here --- e.g. switching from 1x128x1 to 1x32x4 or whatever. Basically taking some factor out of the middle dimension (that 128) into one of the other dimensions that is currently 1.

bjacob commented 11 months ago

Summary of quick meeting with @Max191 :

Max191 commented 11 months ago

I have fixed the correctness issue. It came from the way I was tiling the number of groups dimension to 1 in order to fold away that dimension in the vector.contract.

Here is the IR before and after we tile the number of groups: https://gist.github.com/Max191/750a9d7e325e4fe3a712a193382ba86c

Because it was tiling through tile and fuse, there is a linalg.fill op on the output of the second linalg.generic that gets tiled and fused when we don't want it to. This was causing the output to be reinitialized to 0 for every group in the reduction.
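
In loop form, the bug looks roughly like the following C sketch. This is a simplified illustration with made-up sizes and names, not the generated code: the fill ends up inside the loop over groups, so only the last group's partial sum survives.

```c
#include <assert.h>
#include <stdint.h>

enum { NUM_GROUPS = 4, GROUP_SIZE = 8 };

/* Buggy tiling: the fill is fused into the per-group tile, so the
   output is reset to 0 for every group in the reduction. */
static int32_t reduce_fill_inside(const int32_t *x) {
  int32_t out = 0;
  for (int g = 0; g < NUM_GROUPS; g++) {
    out = 0; /* fused linalg.fill, runs once per group */
    for (int k = 0; k < GROUP_SIZE; k++) {
      out += x[g * GROUP_SIZE + k];
    }
  }
  return out;
}

/* Intended behavior: fill once, then accumulate across all groups. */
static int32_t reduce_fill_outside(const int32_t *x) {
  int32_t out = 0; /* linalg.fill, runs once */
  for (int g = 0; g < NUM_GROUPS; g++) {
    for (int k = 0; k < GROUP_SIZE; k++) {
      out += x[g * GROUP_SIZE + k];
    }
  }
  return out;
}
```

With an input of all ones, the intended version returns 32 while the buggy one returns 8, the sum of the last group only.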

Instead, I am now tiling the number of groups in the 3rd level of tiling, which produces correct results again with a latency of 75-85ms. However, this requires us to remove the check in tile size verification that all 3rd-level tile sizes are on reduction dimensions. I am also seeing some weird behavior with the compilation (the compilation is Killed when compiling e2e, but works if the model is compiled to flow and then from flow).

With this in mind, I think it would be worthwhile to do at least the unrolling step of what @bjacob was suggesting in idea 2. If we unroll the vecmat linalg.generic along the group dimension, then we no longer need to tile in the third tiling level, and it will make some progress towards eventually just using the matmul pipelines. We can probably just do this unrolling in the FuseDequantizationMatmul pass where we create this op to begin with. @MaheshRavishankar @bjacob, let me know what you guys think.

MaheshRavishankar commented 11 months ago

Let's unwrap this a bit today on a video call.