nod-ai / SHARK-Turbine

Unified compiler/runtime for interfacing with PyTorch Dynamo.
Apache License 2.0

failed to legalize unresolved materialization from ('!torch.vtensor<[0],si64>') to '!torch.vtensor<[?],si64>' that remained live after conversion #826

Open pdhirajkumarprasad opened 1 week ago

pdhirajkumarprasad commented 1 week ago

For the given IR, I am getting the following error:

model.mlir:3:12: error: failed to legalize unresolved materialization from ('!torch.vtensor<[0],si64>') to '!torch.vtensor<[?],si64>' that remained live after conversion

    %256 = torch.operator "onnx.Shape"(%arg0) : (!torch.vtensor<[],si64>) -> !torch.vtensor<[?],si64> 

           ^

model.mlir:3:12: note: see current operation: %2 = "builtin.unrealized_conversion_cast"(%0) : (!torch.vtensor<[0],si64>) -> !torch.vtensor<[?],si64>
module {
  func.func @main_graph(%arg0: !torch.vtensor<[],si64> ) -> !torch.vtensor<[?],si64>  attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.1.0"} {
    %256 = torch.operator "onnx.Shape"(%arg0) : (!torch.vtensor<[],si64>) -> !torch.vtensor<[?],si64> 
    return %256: !torch.vtensor<[?],si64>
  }
}

Command:

iree-compile --iree-hal-target-backends=llvm-cpu model.mlir
jinchen62 commented 2 days ago

@pdhirajkumarprasad Do you know how I can get and run these failed models? https://github.com/pdhirajkumarprasad/SHARK-TestSuite/blob/feature/qa/issue/onnx-to-torch/unresolved_materialization I would like to see a bigger IR.

pdhirajkumarprasad commented 2 days ago

Download the scripts below: run.sh.txt upgrade_onnx.py.txt

Rename them to run.sh and upgrade_onnx.py (by removing the .txt extension)

Save the model names in a file, say modelList

Run: ./run.sh modelList

It will save the logs in the ./temp directory

vinayakdsci commented 1 day ago

@pdhirajkumarprasad @jinchen62 The IR is inherently wrong, but this error message is generated by a (possibly) faulty lowering. Shape should not be taken of an empty tensor (scalar), as there is no point in doing that.

This (rough) patch should fix this error and have MLIR produce a correct one instead:

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index 68868e95..4fb5aa53 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -1659,6 +1659,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         auto inputType = dyn_cast<Torch::ValueTensorType>(operand.getType());
         int64_t inputRank = inputType.getSizes().size();

+        if (start == 0 && end == -1) {
+          rewriter.replaceOpWithNewOp<Torch::Aten_ShapeAsTensorOp>(
+              binder.op, resultType, operand);
+          return success();
+        }
+
         auto shapeType = Torch::ValueTensorType::get(
             binder.op->getContext(), SmallVector<int64_t>{inputRank},
             resultType.getOptionalDtype());
@@ -1666,11 +1672,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         Value shape = rewriter.create<Torch::Aten_ShapeAsTensorOp>(
             binder.getLoc(), shapeType, operand);

-        if (start == 0 && end == -1) {
-          rewriter.replaceOp(binder.op, shape);
-          return success();
-        }
-
         Value sv = rewriter.create<Torch::ConstantIntOp>(
             binder.getLoc(), rewriter.getI64IntegerAttr(start));

@@ -1681,10 +1682,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(

         Value dim = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), 0);

-        shape = rewriter.create<Torch::AtenSliceTensorOp>(
-            binder.getLoc(), resultType, shape, dim, sv, ev, step);
+        rewriter.replaceOpWithNewOp<Torch::AtenSliceTensorOp>(
+            binder.op, resultType, shape, dim, sv, ev, step);

-        rewriter.replaceOp(binder.op, shape);
         return success();
       });

I think we should handle scalar inputs and error out more gracefully in such cases.
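
As a rough sketch of that (not part of the patch above; the exact message and placement are hypothetical), the lowering could reject rank-0 inputs up front with a match failure instead of emitting invalid IR:

// Hedged sketch: bail out on rank-0 (scalar) inputs before building the
// shape tensor, so the converter reports a readable failure instead of
// leaving an unrealized_conversion_cast behind.
if (inputType.getSizes().empty())
  return rewriter.notifyMatchFailure(
      binder.op, "onnx.Shape of a rank-0 (scalar) input is not supported");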

vinayakdsci commented 1 day ago

When we returned early in the first if condition, the resultType was never consumed, hence the materialization error.
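
To spell that out (my reading of the pre-patch code in the diff above, so treat it as a hedged reconstruction): the old fast path replaced the op with shape, whose type is shapeType built from inputRank, not the declared resultType:

// Pre-patch fast path: for the rank-0 input above, shapeType ends up as
// !torch.vtensor<[0],si64>, while the op's uses expect the declared
// resultType !torch.vtensor<[?],si64>. Replacing the op with `shape`
// therefore forces the conversion framework to materialize the
// builtin.unrealized_conversion_cast seen in the error, which is never
// resolved.
if (start == 0 && end == -1) {
  rewriter.replaceOp(binder.op, shape);
  return success();
}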