Closed by pdhirajkumarprasad 1 month ago
I usually add -mlir-print-ir-before-all -mlir-print-ir-after-all -mlir-disable-threading 2> ~/log
to pinpoint the failing pass; here it shows that the failure occurs in the ConvertTorchToLinalg pass. This is the IR before the pass:
func.func @"torch-jit-export"(%arg0: !torch.vtensor<[35,1],si64>, %arg1: !torch.vtensor<[2,1,200],f32>, %arg2: !torch.vtensor<[2,1,200],f32>, %arg3: !torch.vtensor<[33278,200],f32>) -> !torch.vtensor<[35,1,200],f32> attributes {torch.assume_strict_symbolic_shapes, torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.1"} {
%int-1 = torch.constant.int -1
%int33278 = torch.constant.int 33278
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%0 = torch.aten.lt.Scalar %arg0, %int0 : !torch.vtensor<[35,1],si64>, !torch.int -> !torch.vtensor<[35,1],i1>
%1 = torch.aten.add.Scalar %arg0, %int33278, %int1 : !torch.vtensor<[35,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[35,1],si64>
%2 = torch.aten.where.self %0, %1, %arg0 : !torch.vtensor<[35,1],i1>, !torch.vtensor<[35,1],si64>, !torch.vtensor<[35,1],si64> -> !torch.vtensor<[35,1],si64>
%3 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
%4 = torch.aten.view %2, %3 : !torch.vtensor<[35,1],si64>, !torch.list<int> -> !torch.vtensor<[35],si64>
%5 = torch.aten.index_select %arg3, %int0, %4 : !torch.vtensor<[33278,200],f32>, !torch.int, !torch.vtensor<[35],si64> -> !torch.vtensor<[35,200],f32>
%6 = torch.aten.unsqueeze %5, %int1 : !torch.vtensor<[35,200],f32>, !torch.int -> !torch.vtensor<[?,?,?],f32>
%7 = torch.tensor_static_info_cast %6 : !torch.vtensor<[?,?,?],f32> to !torch.vtensor<[35,1,200],f32>
return %7 : !torch.vtensor<[35,1,200],f32>
}
This failure comes from the canonicalization of AtenUnflattenIntOp. It was introduced some time back, and rewrites the op into a combination of Unsqueeze and StaticInfoCast. The rewritten Unsqueeze emits dynamic dims in its result type, which MLIR does not expect during shape inference for ExpandShape (which is what the linalg lowering of Unsqueeze produces), hence the assertion failure.
Here is the relevant part of the code: https://github.com/vinayakdsci/torch-mlir/blob/99848265c388099f500de9eac235bf0e2c9ccc0d/lib/Dialect/Torch/IR/TorchOps.cpp#L2173. Fixing this would require computing the correct (static) result type for the Unsqueeze op so the canonicalization emits valid IR.
What happened?
For the given IR
Seeing assertion:
Steps to reproduce your issue
command:
What component(s) does this issue relate to?
Compiler
Version information
No response
Additional context
No response