axeabc opened this issue 1 week ago
Thanks for filing this issue. I have a couple of thoughts:
With a small change that makes the test semantically correct in terms of element count, it appears to work fine:
module {
  func.func @test_reshape_3d(%arg1: tensor<2x3x4xf32>) -> (tensor<6x?x4xf32>, tensor<2x?x12xf32>) {
    %c0 = arith.constant 0 : index
    %0 = tosa.reshape %arg1 {new_shape = array<i64: 6, 1, 4>} : (tensor<2x3x4xf32>) -> tensor<6x?x4xf32>
    %1 = tosa.reshape %arg1 {new_shape = array<i64: 2, 1, 12>} : (tensor<2x3x4xf32>) -> tensor<2x?x12xf32>
    return %0, %1 : tensor<6x?x4xf32>, tensor<2x?x12xf32>
  }
}
Running it with
/bin/mlir-opt tmp2.mlir --tosa-to-tensor
produces:
module {
  func.func @test_reshape_3d(%arg0: tensor<2x3x4xf32>) -> (tensor<6x?x4xf32>, tensor<2x?x12xf32>) {
    %c0 = arith.constant 0 : index
    %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<2x3x4xf32> into tensor<6x4xf32>
    %expanded = tensor.expand_shape %collapsed [[0, 1], [2]] output_shape [6, 1, 4] : tensor<6x4xf32> into tensor<6x1x4xf32>
    %cast = tensor.cast %expanded : tensor<6x1x4xf32> to tensor<6x?x4xf32>
    %collapsed_0 = tensor.collapse_shape %arg0 [[0], [1, 2]] : tensor<2x3x4xf32> into tensor<2x12xf32>
    %expanded_1 = tensor.expand_shape %collapsed_0 [[0, 1], [2]] output_shape [2, 1, 12] : tensor<2x12xf32> into tensor<2x1x12xf32>
    %cast_2 = tensor.cast %expanded_1 : tensor<2x1x12xf32> to tensor<2x?x12xf32>
    return %cast, %cast_2 : tensor<6x?x4xf32>, tensor<2x?x12xf32>
  }
}
@sjarus is right; this definitely shouldn't have led to a crash.
The ReshapeOp verifier doesn't have a check for this case, but it is rather straightforward to handle, since both the new_shape and the input itself are fully statically defined.
This issue is a duplicate of #107969.
Is it correct to add a verifier for tosa.reshape in this way?
if (inputType.hasStaticShape()) {
  int64_t inputElementsNum = inputType.getNumElements();
  // Compute the product of the known (positive) dimensions in the new shape
  int64_t newShapeElementsNum = std::accumulate(
      getNewShape().begin(), getNewShape().end(), 1LL,
      [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
  // Check if the new shape is fully static
  bool isStaticNewShape =
      llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; });
  // Validate the reshape operation
  if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
      (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
    return emitOpError() << "Cannot reshape " << inputElementsNum
                         << " elements into " << newShapeElementsNum;
  }
}
If it's correct, I'll submit a PR.
Thanks @CoTinker for having a look. Please feel free to submit a PR and we can review.
Okay.
git version: 761bf333e378b52614c
system:
Ubuntu 18.04.6 LTS
reproduce with:
mlir-opt -tosa-to-tensor a.mlir
a.mlir:
stack trace: