daquexian / onnx-simplifier

Simplify your onnx model
Apache License 2.0

[BUG] PyTorch trilinear upsampling not supported #58

Open copaah opened 4 years ago

copaah commented 4 years ago

Describe the bug

I'm trying to use PyTorch's interpolate function with the trilinear upsampling mode. The ONNX exporter handles the model, but onnx-simplifier fails with the following error:

onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Resize node. Name:'' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/tensor/upsample.h:281 void onnxruntime::UpsampleBase::ScalesValidation(const std::vector<float>&, onnxruntime::UpsampleMode) const scales.size() == 2 || (scales.size() == 4 && scales[0] == 1 && scales[1] == 1) was false. 'Linear' mode and 'Cubic' mode only support 2-D inputs ('Bilinear', 'Bicubic') or 4-D inputs with the corresponding outermost 2 scale values being 1 in the Resize operator

According to the ONNX documentation, trilinear upsampling should be supported.
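To make the failing condition concrete, here is a minimal Python sketch that mirrors the check quoted in the error message. It is not the onnxruntime source; the function name `linear_scales_supported` is hypothetical, and the 5-D scales `[1, 1, 4, 4, 4]` are an assumption about what the exporter emits for this model (batch and channel scales of 1, then the three spatial scale factors).

```python
def linear_scales_supported(scales):
    """Mirror the condition quoted in the error message (hypothetical helper,
    not the actual onnxruntime implementation)."""
    return len(scales) == 2 or (
        len(scales) == 4 and scales[0] == 1 and scales[1] == 1
    )


# 4-D input (N, C, H, W) with bilinear upsampling: passes the check.
print(linear_scales_supported([1.0, 1.0, 2.0, 2.0]))       # True
# 5-D input (N, C, D, H, W) with trilinear upsampling: rejected,
# because the check only accepts 2 or 4 scale values.
print(linear_scales_supported([1.0, 1.0, 4.0, 4.0, 4.0]))  # False
```

Under that assumption, any Resize node in linear mode with five scale values fails this validation, whatever the values are.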

Model

A minimal reproducible example:

"""
github_repro_example_trilinear.py
"""
import onnx
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F

class MinimalModel(nn.Module):
    def __init__(self):
        super(MinimalModel, self).__init__()

    def forward(self, input_tensor):
        return F.interpolate(input_tensor, scale_factor=[4, 4, 4], mode='trilinear')

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('output_onnx')
    args = parser.parse_args()

    minimal_model = MinimalModel()
    minimal_model = nn.DataParallel(minimal_model)
    minimal_model.cuda()

    # Random deep feature
    input_tensor = torch.rand((1, 8, 20, 128, 128))
    # Check model can do a forward pass
    output_tensor = minimal_model(input_tensor)
    # Check correct size with four times upsampling
    assert output_tensor.shape == (1, 8, 80, 512, 512)

    # Export to onnx
    torch.onnx.export(
        minimal_model.module,
        (input_tensor),
        args.output_onnx,
        export_params=True, verbose=True, training=False, opset_version=11
    )

    original_model = onnx.load(args.output_onnx)
    onnx.checker.check_model(original_model)
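For reference, the shape asserted in the script follows from scaling each spatial dimension by its factor. The sketch below assumes the usual rule that the output size is the floor of input size times scale factor; `interpolated_shape` is a hypothetical helper for illustration, not a PyTorch function.

```python
import math


def interpolated_shape(input_shape, scale_factor):
    """Expected output shape for an (N, C, *spatial) input, assuming
    output_size = floor(input_size * scale) per spatial dimension."""
    batch, channels, *spatial = input_shape
    scaled = [math.floor(size * scale) for size, scale in zip(spatial, scale_factor)]
    return (batch, channels, *scaled)


print(interpolated_shape((1, 8, 20, 128, 128), [4, 4, 4]))  # (1, 8, 80, 512, 512)
```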

Run the repro script:

python github_repro_example_trilinear.py test.onnx

Then invoke onnx-simplifier:

python3 -m onnxsim --skip-optimization test.onnx test.onnx.simple

This produces the error shown above.

Versions:

$ pip3 show onnx-simplifier
Name: onnx-simplifier
Version: 0.2.4
Summary: Simplify your ONNX model
Home-page: https://github.com/daquexian/onnx-simplifier
Author: daquexian
Author-email: daquexian566@gmail.com
License: Apache
Location: /usr/local/lib/python3.6/dist-packages
Requires: onnx, onnxruntime, protobuf
Required-by: 
$ pip show torch
Name: torch
Version: 1.4.0
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3
Location: /home/.../.local/lib/python3.6/site-packages
Requires: 
Required-by: torchvision
$ pip3 show onnx
Name: onnx
Version: 1.6.0
Summary: Open Neural Network Exchange
Home-page: https://github.com/onnx/onnx
Author: bddppq
Author-email: jbai@fb.com
License: UNKNOWN
Location: /home/../.local/lib/python3.6/site-packages
Requires: six, protobuf, numpy, typing-extensions
Required-by: onnxruntime, onnx-simplifier, onnx-tensorrt
$ pip3 show onnxruntime
Name: onnxruntime
Version: 1.2.0
Summary: ONNX Runtime Python bindings
Home-page: UNKNOWN
Author: Microsoft Corporation
Author-email: onnx@microsoft.com
License: MIT License
Location: /usr/local/lib/python3.6/dist-packages
Requires: onnx, numpy
Required-by: onnx-simplifier

daquexian commented 4 years ago

Thanks for your report. Could you please provide your onnxruntime version? Thanks

copaah commented 4 years ago

> Thanks for your report. Could you please provide your onnxruntime version? Thanks

I've updated the issue with my onnxruntime version.

copaah commented 4 years ago

Any update on this?

daquexian commented 4 years ago

@copaah Sorry for the late reply. Unfortunately, the problem comes from onnxruntime. I'm afraid we can only wait for it to support 5-D Resize.