llvm / llvm-project

The LLVM Project is a collection of modular and reusable compiler and toolchain technologies.
http://llvm.org

[linalg] Vectorization failure - masked scalar read + broadcast #116197

Open banach-space opened 2 days ago

banach-space commented 2 days ago

REPRODUCER

func.func @vectorization_test(%extracted_slice : tensor<1x1x3xi32>, %arg0: index, %arg2: index, %3: tensor<2x4xi32>, %4: tensor<1x3x2x4xi32>) -> tensor<1x1x3xi32>{
%c0 = arith.constant 0 : index

%8 = linalg.generic {
  indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
  iterator_types = ["parallel", "parallel", "parallel"]}
  outs(%extracted_slice : tensor<1x1x3xi32>) {
  ^bb0(%out: i32):
    %9 = linalg.index 0 : index
    %extracted = tensor.extract %3[%9, %c0] : tensor<2x4xi32>
    %14 = arith.index_cast %extracted : i32 to index
    %extracted_2 = tensor.extract %4[%c0, %14, %14, %14] : tensor<1x3x2x4xi32>
    linalg.yield %extracted_2 : i32
  } -> tensor<1x1x3xi32>

  return %8 : tensor<1x1x3xi32>
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
    // %2 = transform.structured.vectorize_children_and_apply_patterns %1  { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 vector_sizes [1, 1, 4] {vectorize_nd_extract} : !transform.any_op
    transform.yield
  }
}

ERROR LOG

../file.mlir:10:18: error: 'vector.mask' op expects a 'vector<i1>' mask for the maskable operation
    %extracted = tensor.extract %3[%9, %c0] : tensor<2x4xi32>
                 ^
../file.mlir:10:18: note: see current operation:
%17 = "vector.mask"(%6) ({
  %34 = "vector.transfer_read"(%arg3, %15, %0, %16) <{in_bounds = [true, true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (0, 0, 0)>}> : (tensor<2x4xi32>, index, index, i32) -> vector<1x1x4xi32>
  "vector.yield"(%34) : (vector<1x1x4xi32>) -> ()
}) : (vector<1x1x4xi1>) -> vector<1x1x4xi32>

ANALYSIS

  1. Type of vectorization: masked
  2. Op: tensor.extract --> %extracted = tensor.extract %3[%9, %c0] : tensor<2x4xi32> (effectively a scalar read + broadcast).
  3. The Vectorizer output (generated by vectorizeAsTensorExtract):
    %34 = "vector.transfer_read"(%arg3, %15, %0, %16) <{in_bounds = [true, true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (0, 0, 0)>}> : (tensor<2x4xi32>, index, index, i32) -> vector<1x1x4xi32>

Looks like the mask generated by the vectorizer doesn't match the "expected mask" computed by MaskOp::verify:

  * The Linalg vectorizer, when creating vector.mask, generates a mask based on the static loop sizes and the input vector sizes. This gives vector<1x1x4xi1>.
  * The Op verifier uses inferTransferOpMaskType (https://github.com/llvm/llvm-project/blob/d119d43e92333966125755353f4e6227dd2c70da/mlir/lib/Dialect/Vector/IR/VectorOps.cpp#L4118-L4129), which has no access to the LinalgOp information and instead looks at the permutation map of the masked op, vector.transfer_read. That yields vector<i1> (based on permutation_map = affine_map<(d0, d1) -> (0, 0, 0)>).

To me, the output from the Vectorizer is correct.
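To make the mismatch concrete, here is a rough Python model of the verifier's inference path (my own sketch, not the actual MLIR code): a permutation map's results are modeled as ints for dims and None for broadcast ("0") results.

```python
def infer_mask_shape(perm_results, vec_shape):
    """Rough model of inferTransferOpMaskType: compress unused dims,
    invert the permutation map, compose with the result vector shape."""
    # compressUnusedDims: keep only the input dims that some result uses.
    used = sorted({r for r in perm_results if r is not None})
    renumber = {d: i for i, d in enumerate(used)}
    # inversePermutation: broadcast (None) results are simply dropped.
    inv = {}
    for vec_dim, mem_dim in enumerate(perm_results):
        if mem_dim is not None:
            inv[renumber[mem_dim]] = vec_dim
    # Compose with the vector shape: one mask dim per surviving memory dim.
    return [vec_shape[inv[d]] for d in sorted(inv)]

# The repro: permutation_map = (d0, d1) -> (0, 0, 0), result vector<1x1x4xi32>.
# Every result is a broadcast, so the inferred mask shape is empty, i.e.
# vector<i1>, while the vectorizer supplies vector<1x1x4xi1> -> verifier error.
print(infer_mask_shape([None, None, None], [1, 1, 4]))  # []
```

Under this model, the invalid.mlir example below ((d0, d1, d2) -> (d0, 0, d2) with vector<3x8x7xf32>) infers a vector<3x7xi1> mask, which is exactly why the verifier rejects the vector<3x8x7xi1> mask supplied there.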

RELEVANT DATA POINT

Looking at this example from "invalid.mlir": https://github.com/llvm/llvm-project/blob/8ff2da782d676edddc19d856a853c1ebab999fc2/mlir/test/Dialect/Vector/invalid.mlir#L475-L483

To me, the error is wrong and the example is in fact correct (as in, vector<3x8x7xi1> as a mask makes sense to me). Specifically, as per the docs (emphasis mine):

The masked-off lanes in the result vector are taken from the corresponding lanes of the pass-thru argument, if provided, or left unmodified, otherwise.

Doesn't this mean that the mask shape should always match the result vector shape?

PROPOSED SOLUTION

  * Either fix how the vectorizer handles broadcast dims, or
  * Update inferTransferOpMaskType and, in general, the semantics of broadcast dims when masking is used.

CC @dcaballe

llvmbot commented 2 days ago

@llvm/issue-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

dcaballe commented 2 days ago

Thanks for reporting this! Indeed, the result of inferTransferOpMaskType is wrong. The mask for transfer ops is inferred by inverting the permutation map, since masking applies to the memory side of the operation, which is a single element in this case. The code inverting the map is probably not handling this case properly. I think compressUnusedDims might be turning the broadcasting map (e.g., (d0, d1, ..., dn) -> (0, 0, ..., 0)) into a map with no inputs (e.g., () -> (0, 0, ..., 0)), and things go off the rails after that.
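That hypothesis can be illustrated with a small Python model (my own sketch, not the MLIR implementation; None stands for a constant-0 broadcast result):

```python
def compress_unused_dims(perm_results, num_dims):
    # Model of compressUnusedDims: drop input dims that no result uses,
    # then renumber the remaining ones.
    used = sorted({r for r in perm_results if r is not None})
    renumber = {d: i for i, d in enumerate(used)}
    return [renumber.get(r) for r in perm_results], len(used)

# Pure-broadcast map (d0, d1) -> (0, 0, 0): no result uses any input dim, so
# the compressed map keeps zero inputs, i.e. it becomes () -> (0, 0, 0).
results, ndims = compress_unused_dims([None, None, None], 2)
print(results, ndims)  # [None, None, None] 0
```

With zero inputs left, inverting the map produces nothing to compose with the vector shape, hence the inferred vector<i1> mask type.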

I can look into this later this week, if needed. Just let me know.

banach-space commented 2 days ago

Thanks for replying!

I can look into this later this week, if needed. Just let me know.

Can you take a look at the test that I linked? I think the current semantics are not 100% correct.

Also, let me share what I've tried already (see diff below). I've updated inferTransferOpMaskType with a pre-processing step so that:

  1. (new!) Replace broadcast dims with identities: (d0, d1, ..., dn) -> (0, 0, ..., 0) --> (d0, d1, ..., dn) -> (d0, d1, ..., dn).
  2. (old) Invert the map: (d0, d1, ..., dn) -> (d0, d1, ..., dn) --> (d0, d1, ..., dn) -> (d0, d1, ..., dn) (the identity map is its own inverse).
  3. (old) Finally, proceed with the rest (compose the inverted map with the vector shape).
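These steps can be modeled in Python (my own sketch mirroring the diff at the end of this comment, where each broadcast result is replaced by the dim with the same index as the result position; dims are ints, None marks a broadcast result):

```python
def infer_mask_shape_patched(perm_results, vec_shape):
    # Step 1 (new): replace each broadcast (None) result with the dim whose
    # index equals the result position.
    patched = [r if r is not None else i for i, r in enumerate(perm_results)]
    # Steps 2-3 (old): invert the map and compose with the vector shape.
    # (Assumes the patched map is a permutation, which holds in these cases.)
    inv = {mem_dim: vec_dim for vec_dim, mem_dim in enumerate(patched)}
    return [vec_shape[inv[d]] for d in sorted(inv)]

# The repro, (d0, d1) -> (0, 0, 0) with result vector<1x1x4xi32>:
print(infer_mask_shape_patched([None, None, None], [1, 1, 4]))  # [1, 1, 4]
# The invalid.mlir example, (d0, d1, d2) -> (d0, 0, d2), vector<3x8x7xf32>:
print(infer_mask_shape_patched([0, None, 2], [3, 8, 7]))  # [3, 8, 7]
```

In both cases the inferred mask shape now matches the result vector shape, which is the assumption stated below.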

This makes sense under the assumption that:

for broadcast dims, the corresponding mask size should be identical to the output vector size.

This would mean that this example, extracted from invalid.mlir, is actually valid:

%mask = vector.splat %c1 : vector<3x8x7xi1> 
%0 = vector.transfer_read %arg0[%c3, %c3, %c3], %cst, %mask {permutation_map = affine_map<(d0, d1, d2)->(d0, 0, d2)>} : memref<?x?x?xf32>, vector<3x8x7xf32> 

My update to inferTransferOpMaskType fixes the repro, but ... breaks "invalid.mlir" and other tests. I'm trying to figure out whether it's my assumptions or the tests that are wrong 😅


diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index db199a46e163..5760c8243261 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4118,7 +4118,21 @@ void TransferReadOp::print(OpAsmPrinter &p) {
 VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
                                                  AffineMap permMap) {
   auto i1Type = IntegerType::get(permMap.getContext(), 1);
-  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
+  SmallVector<AffineExpr> newResults;
+  for (auto [idx, expr] : llvm::enumerate(permMap.getResults())) {
+    auto constExpr = dyn_cast<AffineConstantExpr>(expr);
+    if (constExpr && constExpr.getValue() == 0) {
+      auto dim = getAffineDimExpr(idx, vecType.getContext());
+      newResults.push_back(dim);
+    }
+    else
+      newResults.push_back(expr);
+  }
+  AffineMap newMap =
+      AffineMap::get(std::max<int64_t>(permMap.getNumDims(), newResults.size()),
+                     permMap.getNumSymbols(), newResults, vecType.getContext());
+  AffineMap invPermMap = inversePermutation(compressUnusedDims(newMap));
   assert(invPermMap && "Inversed permutation map couldn't be computed");
   SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());