Open hanhanW opened 7 months ago
E.g., we need to make the snippet below work without specifying vector sizes. We also need a test with padding values.
func.func @test_vectorize_pack(%arg0: tensor<32x8x16xf32>, %arg1: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> {
  %pack = tensor.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x8x16xf32> -> tensor<4x1x32x16x2xf32>
  return %pack : tensor<4x1x32x16x2xf32>
}
// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[read:.*]] = vector.transfer_read %{{.*}}[%[[c0]], %[[c0]], %[[c0]]], %[[cst]]
// CHECK-SAME: {in_bounds = [true, true, true]} : tensor<32x8x16xf32>, vector<32x8x16xf32>
// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[read]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
// CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [1, 3, 0, 4, 2] : vector<32x4x2x1x16xf32> to vector<4x1x32x16x2xf32>
// CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<4x1x32x16x2xf32>
// CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]]]
// CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<4x1x32x16x2xf32>, tensor<4x1x32x16x2xf32>
// CHECK: return %[[write]] : tensor<4x1x32x16x2xf32>
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 : !transform.any_op
    transform.yield
  }
}
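The shapes and the permutation in the CHECK lines above can be derived mechanically from the pack attributes. Below is a minimal Python sketch of that derivation, assuming static shapes. It is not the upstream implementation; the function name `pack_vectorization_shapes` and its return layout are hypothetical and exist only for illustration.

```python
def pack_vectorization_shapes(src_shape, inner_dims_pos, inner_tiles,
                              outer_dims_perm=None):
    """Derive the vector shapes used to lower a static tensor.pack.

    Hypothetical helper (not an MLIR API). Returns a 4-tuple:
      read_shape     - shape of the vector.transfer_read result
      stripmined     - target shape of the vector.shape_cast
      transpose_perm - permutation for vector.transpose
      packed_shape   - shape of the packed result
    """
    rank = len(src_shape)
    outer_perm = list(outer_dims_perm) if outer_dims_perm else list(range(rank))
    tile_of = dict(zip(inner_dims_pos, inner_tiles))

    read_shape = []   # source shape rounded up to a multiple of each tile
    stripmined = []   # every tiled dim d becomes the pair [ceil(d / tile), tile]
    outer_pos = {}    # source dim -> index of its outer part in `stripmined`
    tile_pos = {}     # source dim -> index of its tile part in `stripmined`
    for dim, size in enumerate(src_shape):
        outer_pos[dim] = len(stripmined)
        if dim in tile_of:
            tile = tile_of[dim]
            num_tiles = -(-size // tile)  # ceiling division
            stripmined.append(num_tiles)
            tile_pos[dim] = len(stripmined)
            stripmined.append(tile)
            read_shape.append(num_tiles * tile)
        else:
            stripmined.append(size)
            read_shape.append(size)

    # Result dims are the permuted outer dims followed by the inner tiles.
    transpose_perm = ([outer_pos[d] for d in outer_perm]
                      + [tile_pos[d] for d in inner_dims_pos])
    packed_shape = [stripmined[p] for p in transpose_perm]
    return read_shape, stripmined, transpose_perm, packed_shape


if __name__ == "__main__":
    # Attributes taken from @test_vectorize_pack above.
    print(pack_vectorization_shapes([32, 8, 16], [2, 1], [16, 2], [1, 2, 0]))
```

For the attributes in the test, the sketch yields a read of `32x8x16`, a shape cast to `32x4x2x1x16`, the permutation `[1, 3, 0, 4, 2]`, and the packed shape `4x1x32x16x2`, which matches the CHECK lines. When a padding value is involved, `read_shape` is larger than the source shape, so the read can no longer be fully in bounds.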
Thanks, @hanhanW, for the description. I can start working on this.
We have direct vectorization patterns for pack/unpack, but they always require input vector sizes to be provided. That should not be necessary for static shapes: the upstream vectorizer should be able to handle this case, as it does for other linalg ops. To get there, pack/unpack vectorization needs to reuse the vectorization state, like https://github.com/llvm/llvm-project/blob/84ae8cb4af9abafe9f45e69744607aadb38d649a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp#L2023-L2031
Here is the vectorization support added by Max: https://github.com/llvm/llvm-project/pull/78660 We will need to reuse the state manager in the code.
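The requested behavior for the static case can be summarized as a small decision rule: use the caller's vector sizes when they are given, otherwise infer them from the static result shape. The Python sketch below only illustrates that rule and is not the vectorizer's code. The name `resolve_pack_vector_sizes` is hypothetical, dynamic dimensions are modeled as `None`, and it assumes that the vector sizes for a pack correspond to the outer dimensions of the packed result.

```python
def resolve_pack_vector_sizes(dest_shape, num_inner_tiles,
                              input_vector_sizes=None):
    """Pick vector sizes for a pack op (hypothetical helper, not an MLIR API).

    dest_shape uses None for dynamic dimensions. The last `num_inner_tiles`
    entries of dest_shape are the inner tiles; the rest are outer dims.
    """
    # Explicitly provided sizes always win (the behavior available today).
    if input_vector_sizes is not None:
        return list(input_vector_sizes)

    outer_dims = list(dest_shape[:len(dest_shape) - num_inner_tiles])
    # Without explicit sizes, inference is only possible for static shapes.
    if any(dim is None for dim in outer_dims):
        raise ValueError("dynamic outer dims require explicit vector sizes")
    return outer_dims


if __name__ == "__main__":
    # Packed result type from @test_vectorize_pack: tensor<4x1x32x16x2xf32>.
    print(resolve_pack_vector_sizes([4, 1, 32, 16, 2], num_inner_tiles=2))
```

With the result type from the test, the inferred sizes are the three outer dimensions, so the transform script would not need to spell them out.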