Open banach-space opened 2 days ago
@llvm/issue-subscribers-mlir
Author: Andrzej Warzyński (banach-space)
Thanks for reporting this! Indeed, the result of inferTransferOpMaskType is wrong. The mask for transfer ops is inferred by inverting the op's permutation map, because masking applies to the memory side of the operation, which is a single element in this case. The code inverting the map is probably not handling this case properly. I think compressUnusedDims might be turning the broadcasting map (e.g., (d0, d1, ..., dn) -> (0, 0, ..., 0)) into a map with no inputs (e.g., () -> (0, 0, ..., 0)), and things go off the rails after that.
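The inference described above can be modelled outside MLIR. What follows is a minimal standalone C++ sketch, not the MLIR implementation: MapResult and inferMaskShapeModel are hypothetical names, and the model assumes every map result is either a plain dim expression or the broadcast constant 0.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for one result of a transfer permutation map:
// a value k means the result is the dim expression d_k, std::nullopt
// means the result is the broadcast constant 0.
using MapResult = std::optional<unsigned>;

// Simplified model (an assumption, not MLIR code) of
//   inversePermutation(compressUnusedDims(permMap)).compose(vecShape).
// Unused input dims are dropped, and every remaining input dim takes the
// vector size at the first result position that reads it.
std::vector<int64_t> inferMaskShapeModel(unsigned numDims,
                                         const std::vector<MapResult> &results,
                                         const std::vector<int64_t> &vecShape) {
  assert(results.size() == vecShape.size() && "one map result per vector dim");

  // firstUse[d] holds the index of the first map result equal to d_d, if any.
  std::vector<std::optional<std::size_t>> firstUse(numDims);
  for (std::size_t i = 0; i < results.size(); ++i) {
    if (!results[i])
      continue; // Broadcast results contribute nothing to the mask.
    unsigned d = *results[i];
    assert(d < numDims && "dim index out of range");
    if (!firstUse[d])
      firstUse[d] = i;
  }

  // Keep only the dims that are actually used, in their original order.
  std::vector<int64_t> maskShape;
  for (unsigned d = 0; d < numDims; ++d)
    if (firstUse[d])
      maskShape.push_back(vecShape[*firstUse[d]]);
  return maskShape;
}
```

In this model, (d0, d1, d2) -> (d0, 0, d2) with a 3x8x7 vector gives the shape {3, 7}, while an all-broadcast map such as (d0, d1) -> (0, 0, 0) gives an empty shape, i.e. a rank-0 mask.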
I can look into this later this week, if needed. Just let me know.
Thanks for replying!
> I can look into this later this week, if needed. Just let me know.
Can you take a look at the test that I linked? I think that the current semantics are not 100% correct.
Also, let me share what I've tried already (see diff below). I've updated inferTransferOpMaskType with a pre-processing step so that:
(d0, d1, ..., dn) -> (0, 0, ..., 0) --> (d0, d1, ..., dn) -> (d0, d1, ..., dn)
(d0, d1, ..., dn) -> (d0, d1, ..., dn) --> (d0, d1, ..., dn) -> (d0, d1, ..., dn)
This makes sense under the assumption that:
for broadcast dims, the corresponding mask size should be identical to the output vector size.
This would mean that this example, extracted from invalid.mlir, is actually valid:
%mask = vector.splat %c1 : vector<3x8x7xi1>
%0 = vector.transfer_read %arg0[%c3, %c3, %c3], %cst, %mask {permutation_map = affine_map<(d0, d1, d2)->(d0, 0, d2)>} : memref<?x?x?xf32>, vector<3x8x7xf32>
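To make the assumption concrete, here is a minimal standalone C++ sketch of the pre-processing step followed by the same inversion. It is a model under stated assumptions rather than MLIR code: MapResult and inferMaskShapeProposed are hypothetical names, and map results are limited to plain dim expressions or the broadcast constant 0.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for one map result: a value k means d_k,
// std::nullopt means the broadcast constant 0.
using MapResult = std::optional<unsigned>;

// Simplified model (an assumption, not MLIR code) of mask inference with the
// pre-processing step applied first.
std::vector<int64_t> inferMaskShapeProposed(unsigned numDims,
                                            std::vector<MapResult> results,
                                            const std::vector<int64_t> &vecShape) {
  assert(results.size() == vecShape.size() && "one map result per vector dim");

  // Pre-processing: a broadcast at result position idx becomes d_idx.
  for (std::size_t idx = 0; idx < results.size(); ++idx)
    if (!results[idx])
      results[idx] = static_cast<unsigned>(idx);

  // The rewritten map may need more input dims than the original one.
  std::size_t newNumDims = std::max<std::size_t>(numDims, results.size());

  // Inversion as before: each used dim takes the vector size at the first
  // result position that reads it, and unused dims are dropped.
  std::vector<std::optional<std::size_t>> firstUse(newNumDims);
  for (std::size_t i = 0; i < results.size(); ++i) {
    unsigned d = *results[i];
    assert(d < newNumDims && "dim index out of range");
    if (!firstUse[d])
      firstUse[d] = i;
  }
  std::vector<int64_t> maskShape;
  for (std::size_t d = 0; d < newNumDims; ++d)
    if (firstUse[d])
      maskShape.push_back(vecShape[*firstUse[d]]);
  return maskShape;
}
```

Under this model the example above gets the mask shape {3, 8, 7}, and (d0, d1) -> (0, 0, 0) with a 1x1x4 vector gets {1, 1, 4}. One thing the model also shows: a map such as (d0, d1) -> (0, d0) is rewritten to use d0 twice, so the result is {4} for a 4x8 vector, which no longer has the vector's rank.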
My update to inferTransferOpMaskType fixes the repro, but ... breaks "invalid.mlir" and other tests. I'm trying to figure out whether it's my assumptions or the tests that are wrong 😅
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index db199a46e163..5760c8243261 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4118,7 +4118,21 @@ void TransferReadOp::print(OpAsmPrinter &p) {
 VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
                                                  AffineMap permMap) {
   auto i1Type = IntegerType::get(permMap.getContext(), 1);
-  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
+  SmallVector<AffineExpr> newResults;
+  for (auto [idx, expr] : llvm::enumerate(permMap.getResults())) {
+    auto constExpr = dyn_cast<AffineConstantExpr>(expr);
+    if (constExpr && constExpr.getValue() == 0) {
+      auto dim = getAffineDimExpr(idx, vecType.getContext());
+      newResults.push_back(dim);
+    }
+    else
+      newResults.push_back(expr);
+  }
+  AffineMap newMap =
+      AffineMap::get(std::max<int64_t>(permMap.getNumDims(), newResults.size()),
+                     permMap.getNumSymbols(), newResults, vecType.getContext());
+  AffineMap invPermMap = inversePermutation(compressUnusedDims(newMap));
   assert(invPermMap && "Inversed permutation map couldn't be computed");
   SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
REPRODUCER
ERROR LOG
ANALYSIS
tensor.extract --> %extracted = tensor.extract %3[%9, %c0] : tensor<2x4xi32> (effectively a scalar read + broadcast).
vectorizeAsTensorExtract):
Looks like the mask generated by the vectorizer doesn't match the "expected mask" computed by MaskOp::verify:
The Linalg vectorizer, when creating vector.mask, generates a mask based on the static loop sizes and input vector sizes. This gives: vector<1x1x4xi1>.
The Op verifier uses inferTransferOpMaskType, which has no access to the LinalgOp information and instead looks at the permutation map of the masked op, vector.transfer_read. And that yields vector<i1> (based on permutation_map = affine_map<(d0, d1) -> (0, 0, 0)>).
To me, the output from the Vectorizer is correct.
RELEVANT DATA POINT
Looking at this example from "invalid.mlir": https://github.com/llvm/llvm-project/blob/8ff2da782d676edddc19d856a853c1ebab999fc2/mlir/test/Dialect/Vector/invalid.mlir#L475-L483
To me, the error is wrong and the example is in fact correct (as in, vector<3x8x7xi1> as a mask makes sense to me). Specifically, as per the docs (emphasis mine):
Doesn't this mean that the mask shape should always match the result vector shape?
PROPOSED SOLUTION
inferTransferOpMaskType and, in general, the semantics of broadcast dims when masking is used.
CC @dcaballe