iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

[LLVMCPU] Data tiling of group quantized matmul on CPU #14337

Open qedawkins opened 1 year ago

qedawkins commented 1 year ago

The following is an example of a group quantized matmul found in Vicuna (pulled from https://github.com/nod-ai/SHARK/issues/1630, closely related to the i4 IR attached in #12859).

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
module {
  func.func @something(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x1x32x128xf32>) -> tensor<1x1x4096xf32> {
    %cst = arith.constant dense_resource<__elided__> : tensor<4096x32x1xf32>
    %cst_0 = arith.constant dense_resource<__elided__> : tensor<4096x32x1xf32>
    %cst_1 = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<1x1x4096xf32>
    %1 = tensor.empty() : tensor<4096x32x128xf32>
    %2 = linalg.fill ins(%cst_1 : f32) outs(%0 : tensor<1x1x4096xf32>) -> tensor<1x1x4096xf32>
    %3 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %cst, %cst_0 : tensor<4096x32x128xi8>, tensor<4096x32x1xf32>, tensor<4096x32x1xf32>) outs(%1 : tensor<4096x32x128xf32>) {
    ^bb0(%in: i8, %in_2: f32, %in_3: f32, %out: f32):
      %5 = arith.extui %in : i8 to i32
      %6 = arith.uitofp %5 : i32 to f32
      %7 = arith.subf %6, %in_3 : f32
      %8 = arith.mulf %7, %in_2 : f32
      linalg.yield %8 : f32
    } -> tensor<4096x32x128xf32>
    %4 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg1, %3 : tensor<1x1x32x128xf32>, tensor<4096x32x128xf32>) outs(%2 : tensor<1x1x4096xf32>) {
    ^bb0(%in: f32, %in_2: f32, %out: f32):
      %5 = arith.mulf %in, %in_2 : f32
      %6 = arith.addf %5, %out : f32
      linalg.yield %6 : f32
    } -> tensor<1x1x4096xf32>
    return %4 : tensor<1x1x4096xf32>
  }
}

(A trace of the vicuna model is linked in the above SHARK issue for reference.)

The tiling prescribed by the group quantization scheme in this case currently prevents strategies like mmt4d from applying on CPU. The goal of this issue is to focus on how to select a canonical format for representing this operation for codegen purposes (whether that requires a linalg_ext-type op), as well as how to facilitate the dispatch formation, either automatically or as a preprocessing step.

(dispatch_37: https://storage.googleapis.com/shark_tank/dan/second_vic_int8_dispatches/module_forward_dispatch_37.mlir)

hanhanW commented 1 year ago

So the reduction is res[b][n][m] = lhs[b][n][i][j] * rhs[m][i][j]. Maybe we can flatten/linearize the reduction dims, so it becomes res[b][n][m] = lhs[b][n][k] * rhs[m][k]. If we tile the b dimension with tile size 1, it becomes a regular matmul. Then we can apply mmt4d tricks or microkernels. However, we'll need to figure out how to set encodings at the flow level. This seems to be an extension of setting encodings on ContractionOpInterface ops, which needs more study.
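
A hedged sketch (mine, not from the thread) of what that linearized contraction could look like, where %dequant stands in for the dequantized 4096x32x128 weights, %init for the zero-filled accumulator, and %arg1 is the 1x1x32x128 activation from the IR above:

#map_lhs = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map_rhs = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
#map_out = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// Linearize k = 32 * 128 = 4096 on both operands.
%lhs = tensor.collapse_shape %arg1 [[0], [1], [2, 3]] : tensor<1x1x32x128xf32> into tensor<1x1x4096xf32>
%rhs = tensor.collapse_shape %dequant [[0], [1, 2]] : tensor<4096x32x128xf32> into tensor<4096x4096xf32>
// res[b][n][m] = lhs[b][n][k] * rhs[m][k]; tiling b/n with size 1 leaves a regular (transposed-B) matmul.
%res = linalg.generic {
    indexing_maps = [#map_lhs, #map_rhs, #map_out],
    iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
    ins(%lhs, %rhs : tensor<1x1x4096xf32>, tensor<4096x4096xf32>)
    outs(%init : tensor<1x1x4096xf32>) {
^bb0(%a: f32, %b: f32, %acc: f32):
  %mul = arith.mulf %a, %b : f32
  %sum = arith.addf %mul, %acc : f32
  linalg.yield %sum : f32
} -> tensor<1x1x4096xf32>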

qedawkins commented 1 year ago

So the reduction is res[b][n][m] = lhs[b][n][i][j] * rhs[m][i][j]. Maybe we can flatten/linearize the reduction dims, so it becomes res[b][n][m] = lhs[b][n][k] * rhs[m][k]. If we tile the b dimension with tile size 1, it becomes a regular matmul. Then we can apply mmt4d tricks or microkernels. However, we'll need to figure out how to set encodings at the flow level. This seems to be an extension of setting encodings on ContractionOpInterface ops, which needs more study.

Right, part of the problem here is fusing the two to make sure we don't materialize the full dequantized weight tensor in memory. Another way to write this in linalg is with tensor.extract on the scales + zero points and then flattening the reduction dims, but I would have expected keeping the group size as an independent loop would be easier to manage.
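
For reference, a rough sketch (mine, not from the thread) of that tensor.extract variant with the reduction dims already flattened; %act, %weights_2d, %scales, %zps, and %init are placeholder names for the collapsed activation, the collapsed i8 weights, the scales, the zero points, and the zero-filled accumulator:

#map_a = affine_map<(d0, d1, d2) -> (d0, d2)>
#map_w = affine_map<(d0, d1, d2) -> (d1, d2)>
#map_o = affine_map<(d0, d1, d2) -> (d0, d1)>
%c128 = arith.constant 128 : index
%res = linalg.generic {
    indexing_maps = [#map_a, #map_w, #map_o],
    iterator_types = ["parallel", "parallel", "reduction"]}
    ins(%act, %weights_2d : tensor<1x4096xf32>, tensor<4096x4096xi8>)
    outs(%init : tensor<1x4096xf32>) {
^bb0(%a: f32, %w: i8, %acc: f32):
  // Recover the group index from the flattened reduction index (group size 128).
  %row = linalg.index 1 : index
  %k = linalg.index 2 : index
  %group = arith.divui %k, %c128 : index
  %scale = tensor.extract %scales[%row, %group] : tensor<4096x32xf32>
  %zp = tensor.extract %zps[%row, %group] : tensor<4096x32xf32>
  // Dequantize the weight element, then multiply-accumulate.
  %wi = arith.extui %w : i8 to i32
  %wf = arith.uitofp %wi : i32 to f32
  %sub = arith.subf %wf, %zp : f32
  %deq = arith.mulf %sub, %scale : f32
  %mul = arith.mulf %a, %deq : f32
  %sum = arith.addf %mul, %acc : f32
  linalg.yield %sum : f32
} -> tensor<1x4096xf32>

This keeps the group size as an index computation (the divui) rather than as an independent loop, which is the trade-off alluded to above.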

MaheshRavishankar commented 1 year ago

I looked into it a little bit more and tried to document what I was trying. Let's start with the input IR.

module {
  func.func @something(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x1x32x128xf32>) -> tensor<1x1x4096xf32> {
    %cst = arith.constant dense_resource<__elided__> : tensor<4096x32x1xf32>
    %cst_0 = arith.constant dense_resource<__elided__> : tensor<4096x32x1xf32>
    %cst_1 = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<1x1x4096xf32>
    %1 = tensor.empty() : tensor<4096x32x128xf32>
    %2 = linalg.fill ins(%cst_1 : f32) outs(%0 : tensor<1x1x4096xf32>) -> tensor<1x1x4096xf32>
    %3 = linalg.generic {
        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                                        affine_map<(d0, d1, d2) -> (d0, d1, 0)>, 
                                        affine_map<(d0, d1, d2) -> (d0, d1, 0)>, 
                                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
        iterator_types = ["parallel", "parallel", "parallel"]} 
        ins(%arg0, %cst, %cst_0 : tensor<4096x32x128xi8>, tensor<4096x32x1xf32>, tensor<4096x32x1xf32>)
        outs(%1 : tensor<4096x32x128xf32>) {
    ^bb0(%in: i8, %in_2: f32, %in_3: f32, %out: f32):
      %5 = arith.extui %in : i8 to i32
      %6 = arith.uitofp %5 : i32 to f32
      %7 = arith.subf %6, %in_3 : f32
      %8 = arith.mulf %7, %in_2 : f32
      linalg.yield %8 : f32
    } -> tensor<4096x32x128xf32>
    %4 = linalg.generic {
        indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>,
                                        affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>,
                                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>],
        iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
        ins(%arg1, %3 : tensor<1x1x32x128xf32>, tensor<4096x32x128xf32>) 
        outs(%2 : tensor<1x1x4096xf32>) {
    ^bb0(%in: f32, %in_2: f32, %out: f32):
      %5 = arith.mulf %in, %in_2 : f32
      %6 = arith.addf %5, %out : f32
      linalg.yield %6 : f32
    } -> tensor<1x1x4096xf32>
    return %4 : tensor<1x1x4096xf32>
  }
}

This is a bit hard to work through because it is not in a very canonical form. Getting this to a canonical form will need:

* Dropping the `0` values in the affine maps for the first generic, which come from spurious unit-extent dims for `%cst` and `%cst_0`, and adjusting the tensor dimensions accordingly.

After that the computation looks like this.

module {
  func.func @something(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x4096xf32>) -> tensor<1x4096xf32> {
    %cst = arith.constant dense_resource<__elided__> : tensor<4096x32xf32>
    %cst_0 = arith.constant dense_resource<__elided__> : tensor<4096x32xf32>
    %cst_1 = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<1x4096xf32>
    %1 = tensor.empty() : tensor<4096x32x128xf32>
    %2 = linalg.fill ins(%cst_1 : f32) outs(%0 : tensor<1x4096xf32>) -> tensor<1x4096xf32>
    %3 = linalg.generic {
        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                         affine_map<(d0, d1, d2) -> (d0, d1)>, 
                         affine_map<(d0, d1, d2) -> (d0, d1)>, 
                         affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
        iterator_types = ["parallel", "parallel", "parallel"]} 
        ins(%arg0, %cst, %cst_0 : tensor<4096x32x128xi8>, tensor<4096x32xf32>, tensor<4096x32xf32>)
        outs(%1 : tensor<4096x32x128xf32>) {
    ^bb0(%in: i8, %in_2: f32, %in_3: f32, %out: f32):
      %5 = arith.extui %in : i8 to i32
      %6 = arith.uitofp %5 : i32 to f32
      %7 = arith.subf %6, %in_3 : f32
      %8 = arith.mulf %7, %in_2 : f32
      linalg.yield %8 : f32
    } -> tensor<4096x32x128xf32>
    %4 = tensor.collapse_shape %3 [[0], [1, 2]] : tensor<4096x32x128xf32> into tensor<4096x4096xf32>
    %5 = linalg.generic {
        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                         affine_map<(d0, d1, d2) -> (d2, d1)>,
                         affine_map<(d0, d1, d2) -> (d0, d1)>],
        iterator_types = ["parallel", "parallel", "reduction"]}
        ins(%arg1, %4 : tensor<1x4096xf32>, tensor<4096x4096xf32>) 
        outs(%2 : tensor<1x4096xf32>) {
    ^bb0(%in: f32, %in_2: f32, %out: f32):
      %6 = arith.mulf %in, %in_2 : f32
      %7 = arith.addf %6, %out : f32
      linalg.yield %7 : f32
    } -> tensor<1x4096xf32>
    return %5 : tensor<1x4096xf32>
  }
}

Before trying to design a new operation, it should be asked whether the function in its current form can be code-generated efficiently if it were in a single dispatch. The immediate issue is the tensor.collapse_shape, which prevents tile and fuse. One way to avoid this would be to use the following representation.

module {
  func.func @something(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x4096xf32>) -> tensor<1x32x128xf32> {
    %cst = arith.constant dense_resource<__elided__> : tensor<4096x32xf32>
    %cst_0 = arith.constant dense_resource<__elided__> : tensor<4096x32xf32>
    %cst_1 = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<1x32x128xf32>
    %1 = tensor.empty() : tensor<4096x32x128xf32>
    %2 = linalg.fill ins(%cst_1 : f32) outs(%0 : tensor<1x32x128xf32>) -> tensor<1x32x128xf32>
    %3 = linalg.generic {
        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                         affine_map<(d0, d1, d2) -> (d0, d1)>, 
                         affine_map<(d0, d1, d2) -> (d0, d1)>, 
                         affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
        iterator_types = ["parallel", "parallel", "parallel"]} 
        ins(%arg0, %cst, %cst_0 : tensor<4096x32x128xi8>, tensor<4096x32xf32>, tensor<4096x32xf32>)
        outs(%1 : tensor<4096x32x128xf32>) {
    ^bb0(%in: i8, %in_2: f32, %in_3: f32, %out: f32):
      %5 = arith.extui %in : i8 to i32
      %6 = arith.uitofp %5 : i32 to f32
      %7 = arith.subf %6, %in_3 : f32
      %8 = arith.mulf %7, %in_2 : f32
      linalg.yield %8 : f32
    } -> tensor<4096x32x128xf32>
    %4 = linalg.generic {
        indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3)>,
                         affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>,
                         affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
        iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
        ins(%arg1, %3 : tensor<1x4096xf32>, tensor<4096x32x128xf32>) 
        outs(%2 : tensor<1x32x128xf32>) {
    ^bb0(%in: f32, %in_2: f32, %out: f32):
      %6 = arith.mulf %in, %in_2 : f32
      %7 = arith.addf %6, %out : f32
      linalg.yield %7 : f32
    } -> tensor<1x32x128xf32>
    return %4 : tensor<1x32x128xf32>
  }
}

These ops in a single dispatch get past the tile-and-fuse issue and get us to tile granularity. This will go through the generic reduction pipeline by default. At this point we can look at the performance of the code (it should get vectorized given these static shapes), or we can use micro-kernels to get performance.

In terms of proposing a new operation, the advantage is that it will allow us to do fusion by construction. One potential representation that I toyed with would be based off the version above with tensor.collapse_shape.

%5 = iree_linalgext.quantized_matmul
    activation_dequantization_operands(%arg0, %cst, %cst_0 : tensor<4096x32x128xi8>, tensor<4096x32xf32>, tensor<4096x32xf32>)
    dequantization_collapse_shape [[0], [1, 2]]
    weight(%arg1 : tensor<1x4096xf32>)
    activation_dequantization: {
    ^bb0(%b0 : i8, %b1 : f32, %b2 : f32):
      %6 = arith.extui %b0 : i8 to i32
      %7 = arith.uitofp %6 : i32 to f32
      %8 = arith.subf %7, %b1 : f32
      %9 = arith.mulf %8, %b2 : f32
      iree_linalgext.yield %9 : f32
    } -> tensor<1x4096xf32>

But I am not sure this is a final state; it is more of a WIP.

So I think the most immediate next step would be to start with a flow.dispatch.region that contains both these operations so that they come in already fused. That can shake out some issues until we get to the "post-tiled" part. I can start looking into this, unless someone beats me to the punch. I have a few reviews to take care of, so I was going to mostly look at it tomorrow.

MaheshRavishankar commented 1 year ago

cc @qedawkins @hanhanW and @bjacob . @stellaraccident as well to check the reasoning here and provide feedback.

qedawkins commented 1 year ago

This is a bit hard to work through because it is not in a very canonical form. Getting this to a canonical form will need:

* Dropping the `0` values in the affine maps for the first generic, which come from spurious unit-extent dims for `%cst` and `%cst_0`, and adjusting the tensor dimensions accordingly.

+1, this seems like a sensible change to me.

In terms of proposing a new operation, the advantage is that it will allow us to do fusion by construction. One potential representation that I toyed with would be based off the version above with tensor.collapse_shape.

%5 = iree_linalgext.quantized_matmul
    activation_dequantization_operands(%arg0, %cst, %cst_0 : tensor<4096x32x128xi8>, tensor<4096x32xf32>, tensor<4096x32xf32>)
    dequantization_collapse_shape [[0], [1, 2]]
    weight(%arg1 : tensor<1x4096xf32>)
    activation_dequantization: {
    ^bb0(%b0 : i8, %b1 : f32, %b2 : f32):
      %6 = arith.extui %b0 : i8 to i32
      %7 = arith.uitofp %6 : i32 to f32
      %8 = arith.subf %7, %b1 : f32
      %9 = arith.mulf %8, %b2 : f32
      iree_linalgext.yield %9 : f32
    } -> tensor<1x4096xf32>

But I am not sure this is a final state; it is more of a WIP.

One issue with this representation is that we would still want to tile the op based on the group size, but this representation seems to omit the implied expand_shape on the LHS to a 1x32x128xf32 which I assume would make it awkward to implement the tiling interface.

So I think the most immediate next step would be to start with a flow.dispatch.region that contains both these operations so that they come in already fused. That can shake out some issues until we get to the "post-tiled" part. I can start looking into this, unless someone beats me to the punch. I have a few reviews to take care of, so I was going to mostly look at it tomorrow.

Yep, I think this was the plan of action at one point, but after the meeting the conclusion seemed to be to try an op. Moving forward with this as a first step makes sense to me. @max191 can help with pre-forming the dispatches if you're able to help with shaking out the issues that come from that.

qedawkins commented 1 year ago

One issue with this representation is that we would still want to tile the op based on the group size, but this representation seems to omit the implied expand_shape on the LHS to a 1x32x128xf32 which I assume would make it awkward to implement the tiling interface.

Actually this might be fine if we disallow tiling along the collapsed dimension, which would work fine for the first level of tile + fuse (+ distribute) because the collapsed dimensions are being reduced anyway.

MaheshRavishankar commented 1 year ago

I started looking into this. I was able to use this as the input for pushing this through.

func.func @something(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x4096xf32>) -> tensor<1x32x128xf32> {
  %cst = arith.constant dense_resource<__elided__> : tensor<4096x32xf32>
  %cst_0 = arith.constant dense_resource<__elided__> : tensor<4096x32xf32>
  %cst_1 = arith.constant 0.000000e+00 : f32
  %result = flow.dispatch.region -> (tensor<1x32x128xf32>) {
    %0 = tensor.empty() : tensor<1x32x128xf32>
    %1 = tensor.empty() : tensor<4096x32x128xf32>
    %2 = linalg.fill ins(%cst_1 : f32) outs(%0 : tensor<1x32x128xf32>) -> tensor<1x32x128xf32>
    %3 = linalg.generic {
        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                         affine_map<(d0, d1, d2) -> (d0, d1)>,
                         affine_map<(d0, d1, d2) -> (d0, d1)>,
                         affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
        iterator_types = ["parallel", "parallel", "parallel"]}
        ins(%arg0, %cst, %cst_0 : tensor<4096x32x128xi8>, tensor<4096x32xf32>, tensor<4096x32xf32>)
        outs(%1 : tensor<4096x32x128xf32>) {
    ^bb0(%in: i8, %in_2: f32, %in_3: f32, %out: f32):
      %5 = arith.extui %in : i8 to i32
      %6 = arith.uitofp %5 : i32 to f32
      %7 = arith.subf %6, %in_3 : f32
      %8 = arith.mulf %7, %in_2 : f32
      linalg.yield %8 : f32
    } -> tensor<4096x32x128xf32>
    %4 = linalg.generic {
        indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3)>,
                         affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>,
                         affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
        iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
        ins(%arg1, %3 : tensor<1x4096xf32>, tensor<4096x32x128xf32>)
        outs(%2 : tensor<1x32x128xf32>) {
    ^bb0(%in: f32, %in_2: f32, %out: f32):
      %6 = arith.mulf %in, %in_2 : f32
      %7 = arith.addf %6, %out : f32
      linalg.yield %7 : f32
    } -> tensor<1x32x128xf32>
    flow.return %4 : tensor<1x32x128xf32>
    }
  return %result : tensor<1x32x128xf32>
}

This hit a few issues that we need to address. I created the https://github.com/openxla/iree/tree/quantized_matmul branch to have a place where we can make progress on all issues hit independently. The following task list captures what I have found so far:

Flow level issues

After these tasks the dispatch gets to the backend as expected.

LLVMCPU backend issues.

qedawkins commented 1 year ago
func.func @something(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x4096xf32>) -> tensor<1x32x128xf32> {
  %cst = arith.constant dense_resource<__elided__> : tensor<4096x32xf32>
  %cst_0 = arith.constant dense_resource<__elided__> : tensor<4096x32xf32>
  %cst_1 = arith.constant 0.000000e+00 : f32
  %result = flow.dispatch.region -> (tensor<1x32x128xf32>) {
    %0 = tensor.empty() : tensor<1x32x128xf32>
    %1 = tensor.empty() : tensor<4096x32x128xf32>
    %2 = linalg.fill ins(%cst_1 : f32) outs(%0 : tensor<1x32x128xf32>) -> tensor<1x32x128xf32>
    %3 = linalg.generic {
        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                         affine_map<(d0, d1, d2) -> (d0, d1)>,
                         affine_map<(d0, d1, d2) -> (d0, d1)>,
                         affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
        iterator_types = ["parallel", "parallel", "parallel"]}
        ins(%arg0, %cst, %cst_0 : tensor<4096x32x128xi8>, tensor<4096x32xf32>, tensor<4096x32xf32>)
        outs(%1 : tensor<4096x32x128xf32>) {
    ^bb0(%in: i8, %in_2: f32, %in_3: f32, %out: f32):
      %5 = arith.extui %in : i8 to i32
      %6 = arith.uitofp %5 : i32 to f32
      %7 = arith.subf %6, %in_3 : f32
      %8 = arith.mulf %7, %in_2 : f32
      linalg.yield %8 : f32
    } -> tensor<4096x32x128xf32>
    %4 = linalg.generic {
        indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3)>,
                         affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>,
                         affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
        iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
        ins(%arg1, %3 : tensor<1x4096xf32>, tensor<4096x32x128xf32>)
        outs(%2 : tensor<1x32x128xf32>) {
    ^bb0(%in: f32, %in_2: f32, %out: f32):
      %6 = arith.mulf %in, %in_2 : f32
      %7 = arith.addf %6, %out : f32
      linalg.yield %7 : f32
    } -> tensor<1x32x128xf32>
    flow.return %4 : tensor<1x32x128xf32>
    }
  return %result : tensor<1x32x128xf32>
}

Quick note, somewhere along the way I think you swapped the LHS and output types. The contracting dimension should be the one that is grouped in this IR. It was somewhat unclear with the initial IR, but the matmul is transposed as a part of the quantization (I'm guessing that's where the confusion came from).

MaheshRavishankar commented 1 year ago

https://github.com/openxla/iree/pull/14423 is the majority of work needed to fix LinalgFoldUnitDims pass. There are some IREE specific changes that need to happen as well, but those are a small follow up.

MaheshRavishankar commented 1 year ago

Update:

With these two the only thing remaining now is the backend work.

MaheshRavishankar commented 1 year ago

Update:

#14394 has everything needed to treat the dequantized + generic op as a pass through.

Trying out the example, I think the next issue being hit is actually that the upstream vectorizer does not know how to deal with operations that have two reduction iterator types. @qedawkins do you think this is something you guys can pick up? Next week is a bit tight for me, but it just needs an upstream fix.

I'll update the task list.

qedawkins commented 1 year ago

Update: #14394 has everything needed to treat the dequantized + generic op as a pass through. Trying out the example, I think the next issue being hit is actually that the upstream vectorizer does not know how to deal with operations that have two reduction iterator types. @qedawkins do you think this is something you guys can pick up? Next week is a bit tight for me, but it just needs an upstream fix.

Sure, we can handle the upstream fix if needed. Will try to sync offline about details.

qedawkins commented 1 year ago

@MaheshRavishankar The model input for a single layer looks like the following

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
module {
  func.func @quantized_matmul(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x1x32x128xf32>) -> tensor<1x1x4096xf32> {
    %0 = tensor.empty() : tensor<4096x32x1xf32>
    %1 = tensor.empty() : tensor<4096x32x1xf32>
    %cst = arith.constant 0.000000e+00 : f32
    %2 = tensor.empty() : tensor<1x1x4096xf32>
    %3 = tensor.empty() : tensor<4096x32x128xf32>
    %4 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1x1x4096xf32>) -> tensor<1x1x4096xf32>
    %5 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %0, %1 : tensor<4096x32x128xi8>, tensor<4096x32x1xf32>, tensor<4096x32x1xf32>) outs(%3 : tensor<4096x32x128xf32>) {
    ^bb0(%in: i8, %in_0: f32, %in_1: f32, %out: f32):
      %7 = arith.extui %in : i8 to i32
      %8 = arith.uitofp %7 : i32 to f32
      %9 = arith.subf %8, %in_1 : f32
      %10 = arith.mulf %9, %in_0 : f32
      linalg.yield %10 : f32
    } -> tensor<4096x32x128xf32>
    %6 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg1, %5 : tensor<1x1x32x128xf32>, tensor<4096x32x128xf32>) outs(%4 : tensor<1x1x4096xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %7 = arith.mulf %in, %in_0 : f32
      %8 = arith.addf %7, %out : f32
      linalg.yield %8 : f32
    } -> tensor<1x1x4096xf32>
    return %6 : tensor<1x1x4096xf32>
  }
}

We've been working with the IR after collapseUnitExtentDims (to drop the extra unit dimensions on the scales/zero points) so our starting point IR looks like this, which ends up as a matvec.

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map4 = affine_map<(d0, d1, d2) -> (d0)>
module {
  func.func @quantized_matmul(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x1x32x128xf32>) -> tensor<1x1x4096xf32> {
    %0 = tensor.empty() : tensor<4096x32xf32>
    %1 = tensor.empty() : tensor<4096x32xf32>
    %cst = arith.constant 0.000000e+00 : f32
    %collapsed = tensor.collapse_shape %arg1 [[0, 1, 2], [3]] : tensor<1x1x32x128xf32> into tensor<32x128xf32>
    %2 = tensor.empty() : tensor<4096xf32>
    %3 = tensor.empty() : tensor<4096x32x128xf32>
    %4 = linalg.fill ins(%cst : f32) outs(%2 : tensor<4096xf32>) -> tensor<4096xf32>
    %5 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %0, %1 : tensor<4096x32x128xi8>, tensor<4096x32xf32>, tensor<4096x32xf32>) outs(%3 : tensor<4096x32x128xf32>) {
    ^bb0(%in: i8, %in_0: f32, %in_1: f32, %out: f32):
      %7 = arith.extui %in : i8 to i32
      %8 = arith.uitofp %7 : i32 to f32
      %9 = arith.subf %8, %in_1 : f32
      %10 = arith.mulf %9, %in_0 : f32
      linalg.yield %10 : f32
    } -> tensor<4096x32x128xf32>
    %6 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "reduction", "reduction"]} ins(%collapsed, %5 : tensor<32x128xf32>, tensor<4096x32x128xf32>) outs(%4 : tensor<4096xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %7 = arith.mulf %in, %in_0 : f32
      %8 = arith.addf %7, %out : f32
      linalg.yield %8 : f32
    } -> tensor<4096xf32>
    %expanded = tensor.expand_shape %6 [[0, 1, 2]] : tensor<4096xf32> into tensor<1x1x4096xf32>
    return %expanded : tensor<1x1x4096xf32>
  }
}
MaheshRavishankar commented 1 year ago

@MaheshRavishankar The model input for a single layer looks like the following

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
module {
  func.func @quantized_matmul(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x1x32x128xf32>) -> tensor<1x1x4096xf32> {
    %0 = tensor.empty() : tensor<4096x32x1xf32>
    %1 = tensor.empty() : tensor<4096x32x1xf32>
    %cst = arith.constant 0.000000e+00 : f32
    %2 = tensor.empty() : tensor<1x1x4096xf32>
    %3 = tensor.empty() : tensor<4096x32x128xf32>
    %4 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1x1x4096xf32>) -> tensor<1x1x4096xf32>
    %5 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %0, %1 : tensor<4096x32x128xi8>, tensor<4096x32x1xf32>, tensor<4096x32x1xf32>) outs(%3 : tensor<4096x32x128xf32>) {
    ^bb0(%in: i8, %in_0: f32, %in_1: f32, %out: f32):
      %7 = arith.extui %in : i8 to i32
      %8 = arith.uitofp %7 : i32 to f32
      %9 = arith.subf %8, %in_1 : f32
      %10 = arith.mulf %9, %in_0 : f32
      linalg.yield %10 : f32
    } -> tensor<4096x32x128xf32>
    %6 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg1, %5 : tensor<1x1x32x128xf32>, tensor<4096x32x128xf32>) outs(%4 : tensor<1x1x4096xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %7 = arith.mulf %in, %in_0 : f32
      %8 = arith.addf %7, %out : f32
      linalg.yield %8 : f32
    } -> tensor<1x1x4096xf32>
    return %6 : tensor<1x1x4096xf32>
  }
}

Ok, this IR makes sense. I again dropped the `0` in #map1 above and used this within a flow.dispatch.region.

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
module {
  func.func @quantized_matmul(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x1x32x128xf32>) -> tensor<1x1x4096xf32> {
    %cst = arith.constant dense_resource<__elided__> : tensor<4096x32xf32>
    %cst_0 = arith.constant dense_resource<__elided__> : tensor<4096x32xf32>
    %0 = flow.dispatch.region -> (tensor<1x1x4096xf32>) {
      %cst_1 = arith.constant 0.000000e+00 : f32
      %1 = tensor.empty() : tensor<1x1x4096xf32>
      %2 = tensor.empty() : tensor<4096x32x128xf32>
      %3 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<1x1x4096xf32>) -> tensor<1x1x4096xf32>
      %4 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %cst, %cst_0 : tensor<4096x32x128xi8>, tensor<4096x32xf32>, tensor<4096x32xf32>) outs(%2 : tensor<4096x32x128xf32>) {
      ^bb0(%in: i8, %in_2: f32, %in_3: f32, %out: f32):
        %6 = arith.extui %in : i8 to i32
        %7 = arith.uitofp %6 : i32 to f32
        %8 = arith.subf %7, %in_3 : f32
        %9 = arith.mulf %8, %in_2 : f32
        linalg.yield %9 : f32
      } -> tensor<4096x32x128xf32>
      %5 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg1, %4 : tensor<1x1x32x128xf32>, tensor<4096x32x128xf32>) outs(%3 : tensor<1x1x4096xf32>) {
      ^bb0(%in: f32, %in_2: f32, %out: f32):
        %6 = arith.mulf %in, %in_2 : f32
        %7 = arith.addf %6, %out : f32
        linalg.yield %7 : f32
      } -> tensor<1x1x4096xf32>
      flow.return %5 : tensor<1x1x4096xf32>
    }
    return %0 : tensor<1x1x4096xf32>
  }
}

This goes through most of the compilation and has no errors, but some of the configurations picked are off, so we end up with large vectors. If this is fine, I can look further into seeing what's happening here.

We've been working with the IR after collapseUnitExtentDims (to drop the extra unit dimensions on the scales/zero points) so our starting point IR looks like this, which ends up as a matvec.

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map4 = affine_map<(d0, d1, d2) -> (d0)>
module {
  func.func @quantized_matmul(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x1x32x128xf32>) -> tensor<1x1x4096xf32> {
    %0 = tensor.empty() : tensor<4096x32xf32>
    %1 = tensor.empty() : tensor<4096x32xf32>
    %cst = arith.constant 0.000000e+00 : f32
    %collapsed = tensor.collapse_shape %arg1 [[0, 1, 2], [3]] : tensor<1x1x32x128xf32> into tensor<32x128xf32>
    %2 = tensor.empty() : tensor<4096xf32>
    %3 = tensor.empty() : tensor<4096x32x128xf32>
    %4 = linalg.fill ins(%cst : f32) outs(%2 : tensor<4096xf32>) -> tensor<4096xf32>
    %5 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %0, %1 : tensor<4096x32x128xi8>, tensor<4096x32xf32>, tensor<4096x32xf32>) outs(%3 : tensor<4096x32x128xf32>) {
    ^bb0(%in: i8, %in_0: f32, %in_1: f32, %out: f32):
      %7 = arith.extui %in : i8 to i32
      %8 = arith.uitofp %7 : i32 to f32
      %9 = arith.subf %8, %in_1 : f32
      %10 = arith.mulf %9, %in_0 : f32
      linalg.yield %10 : f32
    } -> tensor<4096x32x128xf32>
    %6 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "reduction", "reduction"]} ins(%collapsed, %5 : tensor<32x128xf32>, tensor<4096x32x128xf32>) outs(%4 : tensor<4096xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %7 = arith.mulf %in, %in_0 : f32
      %8 = arith.addf %7, %out : f32
      linalg.yield %8 : f32
    } -> tensor<4096xf32>
    %expanded = tensor.expand_shape %6 [[0, 1, 2]] : tensor<4096xf32> into tensor<1x1x4096xf32>
    return %expanded : tensor<1x1x4096xf32>
  }
}

I am not sure about this IR. I know why the collapse happens, but I would not start with this; I would use the IR I posted above in this comment as a starting point.

Max191 commented 1 year ago

I have been running into a segfault during iree-benchmark-module with a slightly different dequantization + matmul from the same Vicuna model, and it seems to happen when the RematerializeParallelOps pass is removed. I'm working with the following IR

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
module {
  func.func @forward_dispatch_37(%arg0: tensor<4096x86x128xi8>, %arg1: tensor<1x1x86x128xf32>) -> tensor<1x1x4096xf32> {
    %0 = tensor.empty() : tensor<4096x86x1xf32>
    %1 = tensor.empty() : tensor<4096x86x1xf32>
    %2 = tensor.empty() : tensor<1x1x4096xf32>
    %3 = tensor.empty() : tensor<4096x86x128xf32>
    %4 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %0, %1 : tensor<4096x86x128xi8>, tensor<4096x86x1xf32>, tensor<4096x86x1xf32>) outs(%3 : tensor<4096x86x128xf32>) {
    ^bb0(%in: i8, %in_0: f32, %in_1: f32, %out: f32):
      %6 = arith.extui %in : i8 to i32
      %7 = arith.uitofp %6 : i32 to f32
      %8 = arith.subf %7, %in_1 : f32
      %9 = arith.mulf %8, %in_0 : f32
      linalg.yield %9 : f32
    } -> tensor<4096x86x128xf32>
    %5 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg1, %4 : tensor<1x1x86x128xf32>, tensor<4096x86x128xf32>) outs(%2 : tensor<1x1x4096xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %6 = arith.mulf %in, %in_0 : f32
      %7 = arith.addf %6, %out : f32
      linalg.yield %7 : f32
    } -> tensor<1x1x4096xf32>
    return %5 : tensor<1x1x4096xf32>
  }
}

And the same IR with the dequantization and matmul fused:

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
module {
  func.func @forward_dispatch_37(%arg0: tensor<4096x86x128xi8>, %arg1: tensor<1x1x86x128xf32>) -> tensor<1x1x4096xf32> {
    %0 = tensor.empty() : tensor<4096x86x1xf32>
    %1 = tensor.empty() : tensor<4096x86x1xf32>
    %2 = tensor.empty() : tensor<1x1x4096xf32>
    %3 = tensor.empty() : tensor<4096x86x128xf32>
    %4 = flow.dispatch.region -> (tensor<1x1x4096xf32>) {
      %5 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %0, %1 : tensor<4096x86x128xi8>, tensor<4096x86x1xf32>, tensor<4096x86x1xf32>) outs(%3 : tensor<4096x86x128xf32>) {
      ^bb0(%in: i8, %in_0: f32, %in_1: f32, %out: f32):
        %7 = arith.extui %in : i8 to i32
        %8 = arith.uitofp %7 : i32 to f32
        %9 = arith.subf %8, %in_1 : f32
        %10 = arith.mulf %9, %in_0 : f32
        linalg.yield %10 : f32
      } -> tensor<4096x86x128xf32>
      %6 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg1, %5 : tensor<1x1x86x128xf32>, tensor<4096x86x128xf32>) outs(%2 : tensor<1x1x4096xf32>) {
      ^bb0(%in: f32, %in_0: f32, %out: f32):
        %7 = arith.mulf %in, %in_0 : f32
        %8 = arith.addf %7, %out : f32
        linalg.yield %8 : f32
      } -> tensor<1x1x4096xf32>
      flow.return %6 : tensor<1x1x4096xf32>
    }
    return %4 : tensor<1x1x4096xf32>
  }
}

Compiling both with:

iree-compile --iree-llvmcpu-target-triple=x86_64 \
 --iree-hal-target-backends=llvm-cpu \
 --iree-flow-enable-data-tiling \
 --iree-llvmcpu-enable-microkernels \
 --iree-llvmcpu-stack-allocation-limit=256000 \
 --iree-stream-resource-index-bits=64 \
 --iree-vm-target-index-bits=64 \
 --iree-vm-target-truncate-unsupported-floats \
 dispatch_37_fused.mlir \
 -o dispatch_37_fused.vmfb

The fused IR results in a segfault during iree-benchmark-module, but the unfused IR runs fine, and this segfault only arises after #14453. It also seems that the large vector sizes occur as a result of removing RematerializeParallelOps.

The segfault only happens with this dispatch, not with the dispatch we've been discussing so far. The only difference between the dispatches seems to be the shape of the inputs being <4096x86x128xi8>, <1x1x86x128xf32> instead of <4096x32x128xi8>, <1x1x32x128xf32>.

@MaheshRavishankar do you have any thoughts on why removing RematerializeParallelOps might cause this segfault at runtime?

MaheshRavishankar commented 1 year ago

I have been looking at this. I know there is a large vector size issue we are having. Trying to figure out what the final solution is; I don't have anything concrete for it right now.

Thanks for these. I'll take a look.

Max191 commented 1 year ago

Small update:

I believe the segfault was happening due to the fact that the dequantization op does not get tiled in LLVMCPUTile, and the allocation of the dequantization tile on the stack is large enough to cause a segfault with the larger input shape. The RematerializeParallelOps pass was allowing it to tile along with the matmul, which is why it wasn't having problems with that pass active.
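
For scale (my arithmetic, not stated in the thread): an untiled f32 dequantization result for this shape is 4096 x 86 x 128 x 4 bytes ≈ 180 MB, and even a 32-row tile of it is 32 x 86 x 128 x 4 bytes ≈ 1.4 MB, both far above the 256000-byte stack-allocation limit passed on the command line above.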

Even with the default configuration path for linalg.generic ops, there is a segfault with this larger input shape. I was able to get the full model to run by decreasing the tile sizes for the linalg.generic contraction ops, but it doesn't seem like a good idea to be setting tile sizes so specifically like this. Also, I think the performance increase from larger tile sizes for the matmul seems to be significant (the best I have done with the restrictions on tile sizes is ~350ms, while I was able to get ~250ms with RematerializeParallelOps and just default tile sizes). With this in mind, does it seem like a good idea to try to get the dequantization to tile along with the matmul in LLVMCPUTile?

benvanik commented 1 year ago

ah yeah, generally if you find yourself ever needing to add a flag like iree-llvmcpu-stack-allocation-limit it means something is wrong and results are likely to cause explosions

MaheshRavishankar commented 1 year ago

Small update:

I believe the segfault was happening due to the fact that the dequantization op does not get tiled in LLVMCPUTile, and the allocation of the dequantization tile on the stack is large enough to cause a segfault with the larger input shape. The RematerializeParallelOps pass was allowing it to tile along with the matmul, which is why it wasn't having problems with that pass active.

Even with the default configuration path for linalg.generic ops, there is a segfault with this larger input shape. I was able to get the full model to run by decreasing the tile sizes for the linalg.generic contraction ops, but it doesn't seem like a good idea to be setting tile sizes so specifically like this. Also, I think the performance increase from larger tile sizes for the matmul seems to be significant (the best I have done with the restrictions on tile sizes is ~350ms, while I was able to get ~250ms with RematerializeParallelOps and just default tile sizes). With this in mind, does it seem like a good idea to try to get the dequantization to tile along with the matmul in LLVMCPUTile?

That's exactly what I am looking into.

MaheshRavishankar commented 1 year ago

Geez! I should have looked at the indexing maps more carefully (since my priors on what gets fused do not hold anymore due to the custom dispatch formation).

The problem is this

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
...
      %4 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %cst, %cst_0 : tensor<4096x32x128xi8>, tensor<4096x32xf32>, tensor<4096x32xf32>) outs(%2 : tensor<4096x32x128xf32>) {
      ^bb0(%in: i8, %in_2: f32, %in_3: f32, %out: f32):
        ... 
      } -> tensor<4096x32x128xf32>
      %5 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg1, %4 : tensor<1x1x32x128xf32>, tensor<4096x32x128xf32>) outs(%3 : tensor<1x1x4096xf32>) {
      ^bb0(%in: f32, %in_2: f32, %out: f32):
        ...
      } -> tensor<1x1x4096xf32>
      flow.return %5 : tensor<1x1x4096xf32>
    }
    return %0 : tensor<1x1x4096xf32>
  }
}

The %4 (aka producer) is accessed in %5 (aka consumer) using the map affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>. Since we start tile-and-fuse from the consumer, a tile size of 32 is used for loop d2 in %5, which is the innermost parallel loop in the consumer. In the producer, though, this corresponds to the outermost parallel loop. So:

1) There is a bug in how the configuration "propagation" works. It doesn't account for such a transposition because, by construction, we don't create such dispatches at Flow; but this is a custom dispatch.
2) Even if we fixed (1), the tile size for the producer would be 32 for the outermost loop, and the producer tile needed to compute a tile of the consumer would have a size of 32x32x128xf32. That is obviously too large.

I need to go a bit deeper to understand how such a sequence of ops can be code-generated. One question here though is whether this is doing matmul(A, B) or matmul(transpose(A), B) or matmul(A, transpose(B)). I need to look with a fresh set of eyes to figure this out again.

MaheshRavishankar commented 1 year ago

Ok, that is indeed the issue. This is doing matmul(A, transpose(B)). Thinking about this a little bit more and seeing the comments above, rematerialization does seem to give better benefits, but note that there is a trade-off here: the rematerialization pass introduces redundant computation to avoid materializing the buffer, but if that is better here, then that is the case. I see two options here:

1) If we do have a mechanism to create a dispatch for the dequantization operation and the matmul operation, then we can use rematerialization right there. This is essentially a user assertion that the rematerialization is OK.
2) We can let the compiler automatically decide to rematerialize. #14453 does not actually seem to have an impact on our tracked benchmarks. I am happy to revert that patch, but it is a bit of a yolo pass that rematerializes without any consideration.
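
As a hedged illustration (not IR from the thread), ignoring the group structure and the dequantization, the contraction is equivalent to a matmul with an N x K weight operand, i.e. the named op below (%act, %weights, %init are placeholders):

%res = linalg.matmul_transpose_b
    ins(%act, %weights : tensor<1x4096xf32>, tensor<4096x4096xf32>)
    outs(%init : tensor<1x4096xf32>) -> tensor<1x4096xf32>

which is why the innermost parallel dimension of the consumer corresponds to the outermost dimension of the producer above.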

@qedawkins and @Max191 let me know what you guys think.

qedawkins commented 1 year ago

Ok, that is indeed the issue. This is doing matmul(A, transpose(B)). Thinking about this a little bit more and seeing the comments above, rematerialization does seem to give better benefits, but note that there is a trade-off here: the rematerialization pass introduces redundant computation to avoid materializing the buffer, but if that is better here, then that is the case. I see two options here

1. If we do have a mechanism to create a dispatch for the dequantization operation and the matmul operation, then we can use rematerialization right there. This is essentially a user assertion that the rematerialization is OK.

2. We can let the compiler automatically decide to rematerialize. [Remove `RematerializeParallelOps` pass. #14453](https://github.com/openxla/iree/pull/14453) does not actually seem to have an impact on our tracked benchmarks. I am happy to revert that patch, but it is a bit of a yolo pass that rematerializes without any consideration.

@qedawkins and @Max191 let me know what you guys think.

Let's go with 1. That seems like it offers the best flexibility in the short term as I'm not sure we'll always want to rematerialize for SPIR-V anyway. Dematerialization seems like a much harder decision to make than rematerialization, so I agree that the yolo rematerialization in #14453 is something to avoid.

qedawkins commented 1 year ago

https://github.com/openxla/iree/pull/14521 is one way to do it by reintroducing the pass as preprocessing on pre-formed dispatches. I think it would be nice to have the pattern live somewhere in the interim while coming up with better compiler heuristics for rematerialization.

Max191 commented 1 year ago

14521 is one way to do it by reintroducing the pass as preprocessing on pre-formed dispatches. I think it would be nice to have the pattern live somewhere in the interim while coming up with better compiler heuristics for rematerialization.

This seems like it could be a good way to do it to me for now. The only issue then would be setting the lowering configuration, since it is hard to detect the matmul as a contraction when it has been rematerialized.

We could also set the lowering configuration in preprocessing when we fuse the ops, but I would think it is preferable to be setting it in KernelDispatch.cpp, and I have already written a config there for linalg.generic contraction ops. Maybe it would be better to just annotate the dispatch when it is formed, and only rematerialize later on dispatches with the specified annotation after the lowering config has been set.

MaheshRavishankar commented 1 year ago

14521 is one way to do it by reintroducing the pass as preprocessing on pre-formed dispatches. I think it would be nice to have the pattern live somewhere in the interim while coming up with better compiler heuristics for rematerialization.

This seems like it could be a good way to do it to me for now. The only issue then would be setting the lowering configuration, since it is hard to detect the matmul as a contraction when it has been rematerialized.

We could also set the lowering configuration in preprocessing when we fuse the ops, but I would think it is preferable to be setting it in KernelDispatch.cpp, and I have already written a config there for linalg.generic contraction ops. Maybe it would be better to just annotate the dispatch when it is formed, and only rematerialize later on dispatches with the specified annotation after the lowering config has been set.

I think now we need to figure out what lowering configuration we want and make sure the default picked matches it. Do we have an idea of what the performance targets here are, or whether there are some obvious inefficiencies in the code? I'll wait for Quinn's PR to land, rebase, and check again. I am starting out on work to make our threading heuristics better, so this might tie into that.

qedawkins commented 1 year ago

14521 is one way to do it by reintroducing the pass as preprocessing on pre-formed dispatches. I think it would be nice to have the pattern live somewhere in the interim while coming up with better compiler heuristics for rematerialization.

This seems like it could be a good way to do it to me for now. The only issue then would be setting the lowering configuration, since it is hard to detect the matmul as a contraction when it has been rematerialized. We could also set the lowering configuration in preprocessing when we fuse the ops, but I would think it is preferable to be setting it in KernelDispatch.cpp, and I have already written a config there for linalg.generic contraction ops. Maybe it would be better to just annotate the dispatch when it is formed, and only rematerialize later on dispatches with the specified annotation after the lowering config has been set.

I think now we need to figure out what lowering configuration we want and make sure the default picked matches it. Do we have an idea of what the performance targets here are, or whether there are some obvious inefficiencies in the code? I'll wait for Quinn's PR to land, rebase, and check again. I am starting out on work to make our threading heuristics better, so this might tie into that.

https://github.com/openxla/iree/pull/14521 has landed after a few hours of waiting on CI. My understanding is the goal is performance parity with https://github.com/ggerganov/llama.cpp, although I don't know if anyone has gathered concrete numbers from that and I can't find any obviously advertised on the site.

MaheshRavishankar commented 1 year ago

14521 is one way to do it by reintroducing the pass as preprocessing on pre-formed dispatches. I think it would be nice to have the pattern live somewhere in the interim while coming up with better compiler heuristics for rematerialization.

This seems like it could be a good way to do it to me for now. The only issue then would be setting the lowering configuration, since it is hard to detect the matmul as a contraction when it has been rematerialized. We could also set the lowering configuration in preprocessing when we fuse the ops, but I would think it is preferable to be setting it in KernelDispatch.cpp, and I have already written a config there for linalg.generic contraction ops. Maybe it would be better to just annotate the dispatch when it is formed, and only rematerialize later on dispatches with the specified annotation after the lowering config has been set.

I think now we need to figure out what lowering configuration we want and make sure the default picked matches it. Do we have an idea of what the performance targets here are, or whether there are some obvious inefficiencies in the code? I'll wait for Quinn's PR to land, rebase, and check again. I am starting out on work to make our threading heuristics better, so this might tie into that.

14521 has landed after a few hours of waiting on CI. My understanding is the goal is performance parity with ggerganov/llama.cpp, although I don't know if anyone has gathered concrete numbers from that and I can't find any obviously advertised on the site.

Do we have an MLIR model that we can run and start comparing? There is probably going to be a long list of performance issues.

qedawkins commented 1 year ago

14521 has landed after a few hours of waiting on CI. My understanding is the goal is performance parity with ggerganov/llama.cpp, although I don't know if anyone has gathered concrete numbers from that and I can't find any obviously advertised on the site.

Do we have an MLIR model that we can run and start comparing? There is probably going to be a long list of performance issues.

Yes, I believe that can be found here: https://storage.googleapis.com/shark_tank/vicuna/unsharded/mlir/second_vicuna_int8.mlir (@Max191 can confirm). Also @Max191 is working on getting the Llama.cpp performance results so that we have a starting point.

Max191 commented 1 year ago

Yes, I believe that can be found here: https://storage.googleapis.com/shark_tank/vicuna/unsharded/mlir/second_vicuna_int8.mlir (@Max191 can confirm). Also @Max191 is working on getting the Llama.cpp performance results so that we have a starting point.

Yeah, that's the model I have been working with. I will update with some results from llama.cpp once I have downloaded the model weights.

MaheshRavishankar commented 1 year ago

Awesome! Let's look at the profile and see what's next here.

Max191 commented 1 year ago

With llama.cpp, token generation is about 115ms/tok, and we're getting about 250ms/tok.

Here are the dispatches and profile (screenshot attached):

dispatches.zip

This was compiled with rematerialization happening only on the fused dispatches after the lowering configuration was set. I just added an annotation in preprocessing, skipped the first call of rematerializeParallelOps, and restricted the pass to operate on only the annotated ops.

The tile sizes for the top dispatches (34, 5, 37) are all: tile_sizes = [[0, 0, 128, 0, 0], [1, 1, 32, 0, 0], [0, 0, 0, 1, 16], [0, 0, 0, 0, 0]], which is based on the config for matmul ops.
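
For reference, a hedged sketch (mine, not from the thread; the exact attribute spelling is an assumption) of how such a choice is attached to the contraction generic as an IREE lowering config:

#mapA = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
#mapB = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
#mapC = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
#config = #iree_codegen.lowering_config<tile_sizes = [[0, 0, 128, 0, 0], [1, 1, 32, 0, 0], [0, 0, 0, 1, 16], [0, 0, 0, 0, 0]]>
func.func @annotated_contraction(%lhs: tensor<1x1x32x128xf32>, %rhs: tensor<4096x32x128xf32>, %init: tensor<1x1x4096xf32>) -> tensor<1x1x4096xf32> {
  %0 = linalg.generic {indexing_maps = [#mapA, #mapB, #mapC], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
      ins(%lhs, %rhs : tensor<1x1x32x128xf32>, tensor<4096x32x128xf32>)
      outs(%init : tensor<1x1x4096xf32>)
      attrs = {lowering_config = #config} {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.mulf %in, %in_0 : f32
    %2 = arith.addf %1, %out : f32
    linalg.yield %2 : f32
  } -> tensor<1x1x4096xf32>
  return %0 : tensor<1x1x4096xf32>
}

The five entries in each level correspond to the five loops (d0, d1, d2 parallel; d3, d4 reduction) of the contraction above.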

benvanik commented 1 year ago

both dispatch 5 and 34 are parallel + reductions which I think we saw being similarly slow in another model - that's good news if this is a common pattern :)

MaheshRavishankar commented 1 year ago

With llama.cpp, token generation is about 115ms/tok, and we're getting about 250ms/tok.

Here are the dispatches and profile (screenshot attached):

dispatches.zip

This was compiled with rematerialization happening only on the fused dispatches after the lowering configuration was set. I just added an annotation in preprocessing, skipped the first call of rematerializeParallelOps, and restricted the pass to operate on only the annotated ops.

The tile sizes for the top dispatches (34, 5, 37) are all: tile_sizes = [[0, 0, 128, 0, 0], [1, 1, 32, 0, 0], [0, 0, 0, 1, 16], [0, 0, 0, 0, 0]], which is based on the config for matmul ops.

Great! I can look deeper once I fix another issue. Might need Benoit's insights to really get to the bottom of this.

Max191 commented 1 year ago

@MaheshRavishankar @hanhanW Here is the packed IR I made for the dispatch we've been looking at thus far (with the scales and zero points now as function arguments):

#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, 0)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d3, d4, d6, d7)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d3, d4, d5, d6, d7)>
#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d5)>
module {
  func.func @quantized_matmul(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x1x32x128xf32>, %arg2: tensor<4096x32x1xf32>, %arg3: tensor<4096x32x1xf32>) -> tensor<1x1x4096xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<1x1x4096xf32>
    %1 = tensor.empty() : tensor<1x1x128x32xf32>
    %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1x128x32xf32>) -> tensor<1x1x128x32xf32>
    %3 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x1x4096xf32>) -> tensor<1x1x4096xf32>
    %4 = tensor.empty() : tensor<128x32x1x32x1x128xf32>
    %5 = tensor.empty() : tensor<128x32x1x32x1x128xi8>
    %pack = tensor.pack %arg0 inner_dims_pos = [0, 1, 2] inner_tiles = [32, 1, 128] into %5 : tensor<4096x32x128xi8> -> tensor<128x32x1x32x1x128xi8>
    %6 = tensor.empty() : tensor<128x32x1x32x1x1xf32>
    %pack_0 = tensor.pack %arg2 inner_dims_pos = [0, 1, 2] inner_tiles = [32, 1, 1] into %6 : tensor<4096x32x1xf32> -> tensor<128x32x1x32x1x1xf32>
    %7 = tensor.empty() : tensor<128x32x1x32x1x1xf32>
    %pack_1 = tensor.pack %arg3 inner_dims_pos = [0, 1, 2] inner_tiles = [32, 1, 1] into %7 : tensor<4096x32x1xf32> -> tensor<128x32x1x32x1x1xf32>
    %8 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%pack, %pack_0, %pack_1 : tensor<128x32x1x32x1x128xi8>, tensor<128x32x1x32x1x1xf32>, tensor<128x32x1x32x1x1xf32>) outs(%4 : tensor<128x32x1x32x1x128xf32>) {
    ^bb0(%in: i8, %in_3: f32, %in_4: f32, %out: f32):
      %11 = arith.extui %in : i8 to i32
      %12 = arith.uitofp %11 : i32 to f32
      %13 = arith.subf %12, %in_4 : f32
      %14 = arith.mulf %13, %in_3 : f32
      linalg.yield %14 : f32
    } -> tensor<128x32x1x32x1x128xf32>
    %9 = tensor.empty() : tensor<1x1x32x1x1x128xf32>
    %pack_2 = tensor.pack %arg1 inner_dims_pos = [2, 3] inner_tiles = [1, 128] into %9 : tensor<1x1x32x128xf32> -> tensor<1x1x32x1x1x128xf32>
    %10 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel", "reduction", "reduction"]} ins(%pack_2, %8 : tensor<1x1x32x1x1x128xf32>, tensor<128x32x1x32x1x128xf32>) outs(%2 : tensor<1x1x128x32xf32>) {
    ^bb0(%in: f32, %in_3: f32, %out: f32):
      %11 = arith.mulf %in, %in_3 : f32
      %12 = arith.addf %11, %out : f32
      linalg.yield %12 : f32
    } -> tensor<1x1x128x32xf32>
    %unpack = tensor.unpack %10 inner_dims_pos = [2] inner_tiles = [32] into %3 : tensor<1x1x128x32xf32> -> tensor<1x1x4096xf32>
    return %unpack : tensor<1x1x4096xf32>
  }
}

I made this by running the original IR through a transform script I made that packs the matmul op with transform.structured.pack and bubbles the pack up through the dequant with transform.apply_patterns.iree.bubble_pack_unpack. The result didn't match what I expected it to look like, so I made some changes to the resulting IR to get the IR above.
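
For context, a rough sketch (my reconstruction, not the actual script; the match/split-handle plumbing and exact operand syntax are assumptions, and only transform.structured.pack and transform.apply_patterns.iree.bubble_pack_unpack are taken from the comment above) of the kind of transform script described:

transform.sequence failures(propagate) {
^bb0(%module: !transform.any_op):
  // Find the two linalg.generic ops; the second one is the contraction.
  %generics = transform.structured.match ops{["linalg.generic"]} in %module : (!transform.any_op) -> !transform.any_op
  %dequant, %matmul = transform.split_handle %generics : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
  // Pack the contraction: inner tile 32 on d2 (N) and 1x128 on the d3/d4 group dims.
  %packed = transform.structured.pack %matmul packed_sizes = [0, 0, 32, 1, 128] : (!transform.any_op) -> !transform.any_op
  // Bubble the newly created tensor.pack ops up through the dequantization generic.
  %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op
  transform.apply_patterns to %func {
    transform.apply_patterns.iree.bubble_pack_unpack
  } : !transform.any_op
  transform.yield
}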

I will also share the direct output of my transform script to show what changes I made:

#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, 0, d3, d4)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d3, d4, d6, d7)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d3, d4, d5, d6, d7)>
#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d5)>
module {
  func.func @quantized_matmul(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x1x32x128xf32>, %arg2: tensor<4096x32x1xf32>, %arg3: tensor<4096x32x1xf32>) -> tensor<1x1x4096xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<1x1x4096xf32>
    %1 = tensor.empty() : tensor<1x1x128x32xf32>
    %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1x128x32xf32>) -> tensor<1x1x128x32xf32>
    %3 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x1x4096xf32>) -> tensor<1x1x4096xf32>
    %4 = tensor.empty() : tensor<128x32x1x32x1x128xf32>
    %5 = tensor.empty() : tensor<128x32x1x32x1x128xi8>
    %pack = tensor.pack %arg0 inner_dims_pos = [0, 1, 2] inner_tiles = [32, 1, 128] into %5 : tensor<4096x32x128xi8> -> tensor<128x32x1x32x1x128xi8>
    %6 = tensor.empty() : tensor<128x32x1x32x1xf32>
    %pack_0 = tensor.pack %arg2 inner_dims_pos = [0, 1] inner_tiles = [32, 1] into %6 : tensor<4096x32x1xf32> -> tensor<128x32x1x32x1xf32>
    %7 = tensor.empty() : tensor<128x32x1x32x1xf32>
    %pack_1 = tensor.pack %arg3 inner_dims_pos = [0, 1] inner_tiles = [32, 1] into %7 : tensor<4096x32x1xf32> -> tensor<128x32x1x32x1xf32>
    %8 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%pack, %pack_0, %pack_1 : tensor<128x32x1x32x1x128xi8>, tensor<128x32x1x32x1xf32>, tensor<128x32x1x32x1xf32>) outs(%4 : tensor<128x32x1x32x1x128xf32>) {
    ^bb0(%in: i8, %in_3: f32, %in_4: f32, %out: f32):
      %11 = arith.extui %in : i8 to i32
      %12 = arith.uitofp %11 : i32 to f32
      %13 = arith.subf %12, %in_4 : f32
      %14 = arith.mulf %13, %in_3 : f32
      linalg.yield %14 : f32
    } -> tensor<128x32x1x32x1x128xf32>
    %9 = tensor.empty() : tensor<1x1x32x1x1x128xf32>
    %pack_2 = tensor.pack %arg1 inner_dims_pos = [2, 3] inner_tiles = [1, 128] into %9 : tensor<1x1x32x128xf32> -> tensor<1x1x32x1x1x128xf32>
    %10 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel", "reduction", "reduction"]} ins(%pack_2, %8 : tensor<1x1x32x1x1x128xf32>, tensor<128x32x1x32x1x128xf32>) outs(%2 : tensor<1x1x128x32xf32>) {
    ^bb0(%in: f32, %in_3: f32, %out: f32):
      %11 = arith.mulf %in, %in_3 : f32
      %12 = arith.addf %11, %out : f32
      linalg.yield %12 : f32
    } -> tensor<1x1x128x32xf32>
    %unpack = tensor.unpack %10 inner_dims_pos = [2] inner_tiles = [32] into %3 : tensor<1x1x128x32xf32> -> tensor<1x1x4096xf32>
    return %unpack : tensor<1x1x4096xf32>
  }
}
MaheshRavishankar commented 1 year ago

Need to look more, but there is no transpose here... I think that might need some massaging... Thanks for posting this.

MaheshRavishankar commented 1 year ago

cc @bjacob