llvm / llvm-project

The LLVM Project is a collection of modular and reusable compiler and toolchain technologies.
http://llvm.org

[mlir][vector] Add support for linearizing Insert VectorOp in VectorLinearize #92370

Closed akroviakov closed 5 months ago

akroviakov commented 5 months ago

Building on top of #88204, this PR adds support for converting vector.insert into an equivalent vector.shuffle operation that operates on linearized (1-D) vectors.
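
To make the description concrete, here is a minimal standalone sketch of how such a shuffle mask could be computed. It uses plain C++ with no MLIR dependency. The helper name `insertShuffleMask` and its parameters are hypothetical and not part of the PR. It assumes a static position that indexes leading destination dimensions and a row-major layout.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical helper (not an MLIR API): builds the 1-D shuffle mask that
// emulates inserting `srcSize` contiguous elements into a destination of
// shape `dstShape` at the static position `position`.
// Mask values in [0, dstSize) select from the flattened destination;
// values >= dstSize select from the flattened source (second operand).
std::vector<int64_t> insertShuffleMask(const std::vector<int64_t> &dstShape,
                                       const std::vector<int64_t> &position,
                                       int64_t srcSize) {
  int64_t dstSize = 1;
  for (int64_t d : dstShape)
    dstSize *= d;

  // Row-major linearized offset of the insertion point.
  int64_t stride = dstSize;
  int64_t offset = 0;
  for (std::size_t dim = 0; dim < position.size(); ++dim) {
    stride /= dstShape[dim];
    offset += position[dim] * stride;
  }
  assert(offset + srcSize <= dstSize && "source must fit in destination");

  std::vector<int64_t> mask(dstSize);
  // [0, offset): keep the original destination elements.
  std::iota(mask.begin(), mask.begin() + offset, 0);
  // [offset, offset + srcSize): take every element of the source.
  std::iota(mask.begin() + offset, mask.begin() + offset + srcSize, dstSize);
  // [offset + srcSize, dstSize): keep the remaining destination elements.
  std::iota(mask.begin() + offset + srcSize, mask.end(), offset + srcSize);
  return mask;
}
```

For a `2x8x4` destination, a 32-element source and position `[0]`, this yields 64 through 95 followed by 32 through 63, which is the mask the PR's test expects for `vector.insert %arg1, %arg0[0]`.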

github-actions[bot] commented 5 months ago

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

llvmbot commented 5 months ago

@llvm/pr-subscribers-mlir

Author: Artem Kroviakov (akroviakov)

Changes

Building on top of [#88204](https://github.com/llvm/llvm-project/pull/88204), this PR adds support for converting `vector.insert` into an equivalent `vector.shuffle` operation that operates on linearized (1-D) vectors.

---

Full diff: https://github.com/llvm/llvm-project/pull/92370.diff

2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+99-1)
- (modified) mlir/test/Dialect/Vector/linearize.mlir (+29)

``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 802a64b0805ee..55d2903d8427d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -44,6 +44,22 @@ static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
   return true;
 }
 
+static bool isLessThanOrEqualTargetBitWidth(mlir::Type t,
+                                            unsigned targetBitWidth) {
+  VectorType vecType = dyn_cast<VectorType>(t);
+  // Reject index since getElementTypeBitWidth will abort for Index types.
+  if (!vecType || vecType.getElementType().isIndex())
+    return false;
+  // There are no dimension to fold if it is a 0-D vector.
+  if (vecType.getRank() == 0)
+    return false;
+  unsigned trailingVecDimBitWidth =
+      vecType.getShape().back() * vecType.getElementTypeBitWidth();
+  if (trailingVecDimBitWidth > targetBitWidth)
+    return false;
+  return true;
+}
+
 namespace {
 struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -355,6 +371,88 @@ struct LinearizeVectorExtract final
     return success();
   }
 
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the InsertOp to a ShuffleOp that works on a
+/// linearized vector.
+/// Following,
+///   vector.insert %source %destination [ position ]
+/// is converted to :
+///   %source_1d = vector.shape_cast %source
+///   %destination_1d = vector.shape_cast %destination
+///   %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
+///   ] %out_nd = vector.shape_cast %out_1d
+/// `shuffle_indices_1d` is computed using the position of the original insert.
+struct LinearizeVectorInsert final
+    : public mlir::OpConversionPattern<mlir::vector::InsertOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorInsert(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+  mlir::LogicalResult
+  matchAndRewrite(mlir::vector::InsertOp insertOp, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    Type dstTy = getTypeConverter()->convertType(insertOp.getDestVectorType());
+    assert(!(insertOp.getDestVectorType().isScalable() ||
+             cast<VectorType>(dstTy).isScalable()) &&
+           "scalable vectors are not supported.");
+
+    if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
+                                         targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          insertOp, "Can't flatten since targetBitWidth < OpSize");
+
+    // dynamic position is not supported
+    if (insertOp.hasDynamicPosition())
+      return rewriter.notifyMatchFailure(insertOp,
+                                         "dynamic position is not supported.");
+    auto srcTy = insertOp.getSourceType();
+    auto srcAsVec = mlir::dyn_cast<mlir::VectorType>(srcTy);
+    uint64_t srcSize = 0;
+    if (srcAsVec) {
+      srcSize = srcAsVec.getNumElements();
+    } else {
+      return rewriter.notifyMatchFailure(insertOp,
+                                         "scalars are not supported.");
+    }
+
+    auto dstShape = insertOp.getDestVectorType().getShape();
+    const auto dstSize = insertOp.getDestVectorType().getNumElements();
+    auto dstSizeForOffsets = dstSize;
+
+    // compute linearized offset
+    int64_t linearizedOffset = 0;
+    auto offsetsNd = insertOp.getStaticPosition();
+    for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
+      dstSizeForOffsets /= dstShape[dim];
+      linearizedOffset += offset * dstSizeForOffsets;
+    }
+
+    llvm::SmallVector<int64_t, 2> indices(dstSize);
+    auto origValsUntil = indices.begin();
+    std::advance(origValsUntil, linearizedOffset);
+    std::iota(indices.begin(), origValsUntil,
+              0); // original values that remain [0, offset)
+    auto newValsUntil = origValsUntil;
+    std::advance(newValsUntil, srcSize);
+    std::iota(origValsUntil, newValsUntil,
+              dstSize); // new values [offset, offset+srcNumElements)
+    std::iota(newValsUntil, indices.end(),
+              linearizedOffset + srcSize); // the rest of original values
+                                           // [offset+srcNumElements, end)
+
+    rewriter.replaceOpWithNewOp<mlir::vector::ShuffleOp>(
+        insertOp, dstTy, adaptor.getDest(), adaptor.getSource(),
+        rewriter.getI64ArrayAttr(indices));
+
+    return mlir::success();
+  }
+
 private:
   unsigned targetVectorBitWidth;
 };
@@ -410,6 +508,6 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
                    : true;
       });
   patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
-               LinearizeVectorExtractStridedSlice>(
+               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
       typeConverter, patterns.getContext(), targetBitWidth);
 }
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index b29ceab5783d7..31a59b809a74b 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -245,3 +245,32 @@ func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
   %0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32>
   return %0 : vector<8x2xf32>
 }
+
+// -----
+// ALL-LABEL: test_vector_insert
+// ALL-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> {
+func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> {
+  // DEFAULT: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32>
+  // DEFAULT: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
+  // DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]]
+  // DEFAULT-SAME: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87,
+  // DEFAULT-SAME: 88, 89, 90, 91, 92, 93, 94, 95, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
+  // DEFAULT-SAME: 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<32xf32>
+  // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32>
+  // DEFAULT: return %[[RES]] : vector<2x8x4xf32>
+
+  // BW-128: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32>
+  // BW-128: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
+  // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]]
+  // BW-128-SAME: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87,
+  // BW-128-SAME: 88, 89, 90, 91, 92, 93, 94, 95, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
+  // BW-128-SAME: 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<32xf32>
+  // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32>
+  // BW-128: return %[[RES]] : vector<2x8x4xf32>
+
+  // BW-0: %[[RES:.*]] = vector.insert %[[SRC]], %[[DEST]] [0] : vector<8x4xf32> into vector<2x8x4xf32>
+  // BW-0: return %[[RES]] : vector<2x8x4xf32>
+
+  %0 = vector.insert %arg1, %arg0[0]: vector<8x4xf32> into vector<2x8x4xf32>
+  return %0 : vector<2x8x4xf32>
+}
``````````

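
The bit-width gate in the diff decides whether the pattern fires at all. A rough standalone sketch of that check follows, assuming a hypothetical `VecDesc` struct in place of MLIR's vector type and a hypothetical function name `fitsTargetBitWidth`. Unlike the code in the diff, it widens the product to 64 bits and does not model the index-element rejection.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in for an MLIR vector type: a static shape plus the
// bit width of one element (for example 32 for f32).
struct VecDesc {
  std::vector<int64_t> shape;
  unsigned elementBitWidth;
};

// Sketch of the gate: a 0-D vector has no dimension to fold; otherwise the
// innermost (trailing) dimension must fit into the target bit width.
bool fitsTargetBitWidth(const VecDesc &v, unsigned targetBitWidth) {
  if (v.shape.empty())
    return false;
  uint64_t trailingBits =
      static_cast<uint64_t>(v.shape.back()) * v.elementBitWidth;
  return trailingBits <= targetBitWidth;
}
```

Under these assumptions, the test's `vector<8x4xf32>` source has a trailing dimension of 4 x 32 = 128 bits. It passes with a target width of 128 and fails with a target width of 0, which would be consistent with the `BW-128` checks expecting a `vector.shuffle` and the `BW-0` checks expecting the original `vector.insert` to remain.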
llvmbot commented 5 months ago

@llvm/pr-subscribers-mlir-vector

Author: Artem Kroviakov (akroviakov)

github-actions[bot] commented 5 months ago

:white_check_mark: With the latest revision this PR passed the C/C++ code formatter.

Hardcode84 commented 5 months ago

LGTM, but please wait for other reviewers

github-actions[bot] commented 5 months ago

@akroviakov Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR.

Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues.

How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!