apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug] [Relay] [ONNX] Incorrect shape inference of `Squeeze` in `DynamicToStatic` #17050

Closed shaoyuyoung closed 2 weeks ago

shaoyuyoung commented 4 months ago

Description

This torch model has only two ops: `ReflectionPad3d` and `squeeze`.

First, I exported the torch model to an ONNX model, which gave the graph below. [image: model onnx]

The ONNX export transforms the model in its own way. The result is a dynamic graph that contains an `If` branch structure because of the squeeze operator.

ONNX considers this model valid. However, when I used Relay to convert the model, I hit a shape mismatch error. The correct shape should be `Tensor[(13, 1, 1, 1), float32]`, but TVM got `Tensor[(13, 13, 1, 1), float32]`.

(I think maybe) TVM has a bug in the `DynamicToStatic` pass :(
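To make the branch structure concrete, here is a minimal pure-Python sketch of the shape rule behind `torch.squeeze(x, dim)`: the axis is dropped only when its size is 1, otherwise the tensor passes through unchanged. The helper name `squeeze_dim` is hypothetical and is not part of torch, ONNX, or TVM. The mapping of the two cases to the `If` node's branches is an assumption based on the description above.

```python
def squeeze_dim(shape, dim):
    """Shape after squeezing a single axis: drop `dim` only if its size is 1.

    Hypothetical helper for illustration; it only models shapes, not data.
    """
    shape = tuple(shape)
    dim = dim % len(shape)  # normalise a negative axis index
    if shape[dim] == 1:
        # Assumed "then" case: the axis has size 1 and is removed.
        return shape[:dim] + shape[dim + 1:]
    # Assumed "else" case: the axis is kept and the shape is unchanged.
    return shape


print(squeeze_dim((13, 1, 1, 1), 1))   # (13, 1, 1)
print(squeeze_dim((13, 13, 1, 1), 1))  # (13, 13, 1, 1)
```

Because the outcome depends on the size of one axis, a converter needs the correct static shape of the squeeze input to pick a branch at compile time.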
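As a sanity check on the expected shape, the padding arithmetic can be done by hand. The sketch below assumes PyTorch's padding order (pairs of left/right amounts, starting from the last dimension) and that a negative amount simply shrinks the dimension. The helper `padded_shape` is hypothetical and only models shapes.

```python
def padded_shape(shape, pad):
    """Shape after padding the trailing dims of `shape`.

    `pad` is assumed to follow PyTorch ordering:
    (last_left, last_right, second_last_left, second_last_right, ...).
    Negative amounts shrink the dimension.
    """
    out = list(shape)
    for i in range(len(pad) // 2):
        axis = len(out) - 1 - i  # walk the axes from last to first
        out[axis] += pad[2 * i] + pad[2 * i + 1]
    return tuple(out)


# Input shape and padding taken from the reproduction script in this issue.
print(padded_shape((13, 47, 44, 1), (0, 0, -43, 0, 0, -46)))  # (13, 1, 1, 1)
```

Under these assumptions the tensor feeding the squeeze has shape `(13, 1, 1, 1)`, which agrees with the type the Relay checker expected, so the `(13, 13, 1, 1)` type appears to come from the shape inference inside TVM.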

Code

import onnx
import torch
import torch.nn as nn
import torch.onnx
from tvm import relay, relax

def get_onnx_shape(onnx_model):
    # Build {input name: static shape} from the ONNX graph inputs.
    input_shapes = {}
    for graph_input in onnx_model.graph.input:
        shape = []
        for dim in graph_input.type.tensor_type.shape.dim:
            if dim.dim_value > 0:
                shape.append(dim.dim_value)
            else:
                # Unknown or symbolic dimension: fall back to 1.
                shape.append(1)

        input_shapes[graph_input.name] = shape
    return input_shapes

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.pad = nn.ReflectionPad3d((0, 0, -43, 0, 0, -46))

    def forward(self, x):
        x = self.pad(x)
        x = torch.squeeze(x, dim=1)
        return x

model = Model()

input_tensor = torch.randn(13, 47, 44, 1)

onnx_file_path = "model.onnx"
torch.onnx.export(model,
                  input_tensor,
                  onnx_file_path,
                  export_params=True,
                  opset_version=14,
                  do_constant_folding=False,
                  input_names=['input'],
                  output_names=['output']
                  )

onnx_model = onnx.load("model.onnx")
shape_dict = get_onnx_shape(onnx_model)

mod, params = relay.frontend.from_onnx(
    onnx_model, shape_dict, freeze_params=True
)

Error Log

```
TVMError: Traceback (most recent call last):
  20: tvm::runtime::PackedFuncObj::Extractor::AssignTypedLambda(tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}, std::__cxx11::basic_string, std::allocator >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string, std::allocator >, tvm::runtime::TVMRetValue)
  19: tvm::transform::Pass::operator()(tvm::IRModule) const
  18: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  17: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  16: _ZN3tvm7runtime13PackedFun
  15: tvm::runtime::TypedPackedFunc::AssignTypedLambda(tvm::relay::transform::DynamicToStatic()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  14: tvm::relay::DynamicToStatic(tvm::relay::Function, tvm::IRModule)
  13: tvm::relay::DynamicToStaticMutator::PrepareInput(tvm::RelayExpr const&)
  12: tvm::transform::Pass::operator()(tvm::IRModule) const
  11: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  10: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  9: tvm::transform::Pass::operator()(tvm::IRModule) const
  8: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  7: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  6: _ZN3tvm7runtime13PackedFun
  5: tvm::runtime::TypedPackedFunc::AssignTypedLambda(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  4: tvm::DiagnosticContext::Render()
  3: tvm::DiagnosticRenderer::Render(tvm::DiagnosticContext const&)
  2: tvm::runtime::PackedFuncObj::Extractor::AssignTypedLambda(tvm::TerminalRenderer(std::ostream&)::{lambda(tvm::DiagnosticContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  1: tvm::ReportAt(tvm::DiagnosticContext const&, std::ostream&, tvm::Span const&, tvm::Diagnostic const&)
  0: _ZN3tvm7runtime6deta
  File "/workspace/tvm/src/ir/diagnostic.cc", line 264
TVMError: The source maps are not populated for this module. Please use `tvm.relay.transform.AnnotateSpans` to attach source maps for error reporting.
Error: The Relay type checker is unable to show the following types match:
  Tensor[(13, 13, 1, 1), float32]
  Tensor[(13, 1, 1, 1), float32]
  In particular:
    dimension 1 conflicts: 13 does not match 1.
```

Environment

TVM commit d1ac1c0202b3d8cb2af268ce79c2ac710554152b, Ubuntu 20

cc @KJlaccHoeUM9l @shingjan

xhmelon commented 2 weeks ago

This issue has been fixed by https://github.com/apache/tvm/pull/17383 .