nod-ai / SHARK-ModelDev

Unified compiler/runtime for interfacing with PyTorch Dynamo.
Apache License 2.0

onnx.If errors when the two branches yield with different shapes #696

Open renxida opened 4 months ago

renxida commented 4 months ago

I was working on #566

Found the problem in KeypointRCNN (gist with stripped IR) at %5503 (the above link takes you to the correct line).

When lowering an onnx.If op whose two branches return different types, we encounter:

./repro3.mlir:2:10: error: 'torch.prim.If' op  along control flow edge from Region #0 to parent results: source type #0 '!torch.vtensor<[1],si64>' should match input type #0 '!torch.vtensor<[?],si64>'
    %0 = torch.operator "onnx.If"(%arg0) : (!torch.vtensor<[1],i1>) -> !torch.vtensor<[?],si64> {
         ^
./repro3.mlir:2:10: note: see current operation: 
%2 = "torch.prim.If"(%1) ({
  %3 = "torch.vtensor.literal"() <{value = dense<0> : tensor<1xsi64>}> : () -> !torch.vtensor<[1],si64>
  "torch.prim.If.yield"(%3) : (!torch.vtensor<[1],si64>) -> ()
}, {
  "torch.prim.If.yield"(%arg1) : (!torch.vtensor<[?],si64>) -> ()
}) : (!torch.bool) -> !torch.vtensor<[?],si64>
// -----// IR Dump After ConvertTorchOnnxToTorch Failed (convert-torch-onnx-to-torch) //----- //
"func.func"() <{function_type = (!torch.vtensor<[1],i1>, !torch.vtensor<[?],si64>) -> !torch.vtensor<[?],si64>, sym_name = "minimal_example"}> ({
^bb0(%arg0: !torch.vtensor<[1],i1>, %arg1: !torch.vtensor<[?],si64>):
  %0 = "torch.aten.item"(%arg0) : (!torch.vtensor<[1],i1>) -> !torch.int
  %1 = "torch.aten.Bool.int"(%0) : (!torch.int) -> !torch.bool
  %2 = "torch.prim.If"(%1) ({
    %3 = "torch.vtensor.literal"() <{value = dense<0> : tensor<1xsi64>}> : () -> !torch.vtensor<[1],si64>
    "torch.prim.If.yield"(%3) : (!torch.vtensor<[1],si64>) -> ()
  }, {
    "torch.prim.If.yield"(%arg1) : (!torch.vtensor<[?],si64>) -> ()
  }) : (!torch.bool) -> !torch.vtensor<[?],si64>
  "func.return"(%2) : (!torch.vtensor<[?],si64>) -> ()
}) {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.13.1"} : () -> ()

Reproducer:

func.func @minimal_example(%arg0: !torch.vtensor<[1],i1>, %arg1: !torch.vtensor<[?], si64>) -> !torch.vtensor<[?],si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.13.1"} {
    %0 = torch.operator "onnx.If"(%arg0) : (!torch.vtensor<[1],i1>) -> !torch.vtensor<[?],si64> {
        torch.operator_terminator %arg1 : !torch.vtensor<[?],si64>
    }, {
        %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
        torch.operator_terminator %1 : !torch.vtensor<[1],si64>
    }
    return %0 : !torch.vtensor<[?],si64>
}
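
For reference, the error above should reproduce with something like torch-mlir-opt --convert-torch-onnx-to-torch repro3.mlir (assuming a local torch-mlir build; the pass name is taken from the IR dump above).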
renxida commented 4 months ago

I'm not sure if we want to support this, but if we do, I think we need to insert a type conversion from vtensor<[1],si64> to vtensor<[?],si64>. I'm not sure how to materialize that.
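
Rough sketch of what I mean, assuming torch.tensor_static_info_cast is an acceptable way to materialize the conversion (untested); the static-shaped yield would be relaxed to the dynamic result type before yielding:

%2 = torch.prim.If %1 -> (!torch.vtensor<[?],si64>) {
  %3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
  // Relax the static shape to the dynamic result type before yielding.
  %4 = torch.tensor_static_info_cast %3 : !torch.vtensor<[1],si64> to !torch.vtensor<[?],si64>
  torch.prim.If.yield %4 : !torch.vtensor<[?],si64>
} else {
  torch.prim.If.yield %arg1 : !torch.vtensor<[?],si64>
}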

renxida commented 4 months ago

KeypointRCNN_vaiq_int8 has if statements that return not just different shapes but different ranks (a hypothetical example of what that looks like is below). We can't support that, and something will need to be done on the model side.
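
For illustration only (a made-up variant of the reproducer above, not IR taken from the model), a rank mismatch between the two branches would look like:

func.func @rank_mismatch_example(%arg0: !torch.vtensor<[1],i1>, %arg1: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?],si64> {
    %0 = torch.operator "onnx.If"(%arg0) : (!torch.vtensor<[1],i1>) -> !torch.vtensor<[?],si64> {
        torch.operator_terminator %arg1 : !torch.vtensor<[?],si64>
    }, {
        %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1x1xsi64>} : () -> !torch.vtensor<[1,1],si64>
        torch.operator_terminator %1 : !torch.vtensor<[1,1],si64>
    }
    return %0 : !torch.vtensor<[?],si64>
}

Here the branches yield rank-1 and rank-2 tensors respectively, so there is no common ranked result type for the lowered torch.prim.If to carry.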

renxida commented 3 months ago

Stella mentioned a way to deal with similar problems:

https://discord.com/channels/973663919757492264/1173330951791706113/1246504997269798945

and @zjgarvey's comment in this morning's meeting got me to put two and two together. I'll try to just remove the small branch of these onnx.If ops.