schyun9212 / maskrcnn-benchmark

Converting maskrcnn-benchmark model to TorchScript or ONNX
MIT License

INVALID_GRAPH : This is an invalid model. Type Error: Type 'tensor(float)' of input parameter (44) of operator (Equal) in node () is invalid. #10

schyun9212 closed this issue 4 years ago

schyun9212 commented 4 years ago

🐛 Bug

I hit this error while testing the feature extractor in the ROI box head module.

To Reproduce

import torch
import io
import unittest

from maskrcnn_benchmark.structures.image_list import ImageList

from demo.unittest.onnx.export import ONNXExportTester, ONNX_OPSET_VERSION, VALIDATION_TYPE, cfg, coco_demo, sample_features, sample_proposals, t_width, t_height

class FeatureExtractorTester(ONNXExportTester):
    def test_feature_extractor(self):
        from maskrcnn_benchmark.structures.bounding_box import BoxList
        from maskrcnn_benchmark.modeling.roi_heads.box_head.roi_box_feature_extractors import make_roi_box_feature_extractor

        class FeatureExtractor(torch.nn.Module):
            def __init__(self):
                super(FeatureExtractor, self).__init__()
                self.feature_extractor = make_roi_box_feature_extractor(cfg, 256)

            def forward(self, features, proposals):
                bbox, objectness = proposals

                proposals = BoxList(bbox, (t_width, t_height), mode="xyxy")
                proposals.add_field("objectness", objectness)

                x = self.feature_extractor(features, [proposals])

                return x

        feature_extractor = FeatureExtractor()

        inputs, outputs = self.run_model(feature_extractor, (sample_features, sample_proposals))

        if VALIDATION_TYPE == "IO":
            # Export to an in-memory buffer
            onnx_io = io.BytesIO()
        else:
            # Export to a file on disk
            onnx_io = "./demo/onnx_test_models/feature_extractor.onnx"

        torch.onnx.export(feature_extractor, inputs, onnx_io,
                          input_names=["feature_0", "feature_1", "feature_2", "feature_3", "feature_4", "bbox", "objectness"],
                          opset_version=ONNX_OPSET_VERSION)

        self.ort_validate(onnx_io, inputs, outputs)

if __name__ == '__main__':
    unittest.main()

Actual behavior

ERROR: test_feature_extractor (__main__.ROIBoxHeadTester)
Traceback (most recent call last):
  File "/home/jade/Workspace/maskrcnn/maskrcnn-benchmark-1-3-1/demo/unittest/onnx/export/", line 45, in test_feature_extractor
    self.ort_validate(onnx_io, inputs, outputs)
  File "/home/jade/Workspace/maskrcnn/maskrcnn-benchmark-1-3-1/demo/unittest/onnx/export/", line 73, in ort_validate
    ort_session = onnxruntime.InferenceSession(onnx_io)
  File "/home/jade/.pyenv/versions/maskrcnn-benchmark-1-3-1/lib/python3.7/site-packages/onnxruntime/capi/", line 25, in __init__
  File "/home/jade/.pyenv/versions/maskrcnn-benchmark-1-3-1/lib/python3.7/site-packages/onnxruntime/capi/", line 43, in _load_model
onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : This is an invalid model. Type Error: Type 'tensor(float)' of input parameter (44) of operator (Equal) in node () is invalid.

Ran 1 test in 0.911s

FAILED (errors=1)

Environment

PyTorch version: 1.3.1
Is debug build: No
CUDA used to build PyTorch: 10.1.243

OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: version 3.10.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.1.243
GPU models and configuration: GPU 0: GeForce RTX 2080 Ti
Nvidia driver version: 440.48.02
cuDNN version: Probably one of the following:
/usr/local/cuda-10.0/targets/x86_64-linux/lib/
/usr/local/cuda-10.1/targets/x86_64-linux/lib/
/usr/local/cuda-10.2/targets/x86_64-linux/lib/

Versions of relevant libraries:
[pip3] numpy==1.18.1
[pip3] onnx==1.6.0
[pip3] onnxruntime==1.1.0
[pip3] Pillow==6.2.2
[pip3] torch==1.3.1
[pip3] torchvision==0.4.2
[conda] Could not collect

schyun9212 commented 4 years ago

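This error was caused by an operation that compares a float tensor with a constant (levels == level, where levels is a float tensor). In opsets before 11, ONNX's Equal operator only accepts bool, int32, and int64 inputs, so ONNX Runtime rejects the exported graph when it loads it.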
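The rejection comes from Equal's type constraints in the ONNX operator spec. The lookup below is a hypothetical helper (not part of onnx or onnxruntime) that encodes only the opset 7 and opset 11 type lists for Equal; it assumes the export used an opset below 11, which this report does not state.

```python
# Hypothetical helper: input types accepted by ONNX's Equal operator,
# per the ONNX operator spec. Only the opset 7 and 11 revisions are modeled.
EQUAL_TYPES_OPSET7 = {"tensor(bool)", "tensor(int32)", "tensor(int64)"}
EQUAL_TYPES_OPSET11 = EQUAL_TYPES_OPSET7 | {
    "tensor(int8)", "tensor(int16)",
    "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)",
    "tensor(float16)", "tensor(float)", "tensor(double)",
}


def equal_accepts(onnx_type: str, opset: int) -> bool:
    """Return True if Equal accepts inputs of `onnx_type` at `opset` (7-12 only)."""
    if not 7 <= opset <= 12:
        raise ValueError("only opsets 7-12 are modeled in this sketch")
    # Opset 11 widened Equal to all numeric types; earlier revisions did not
    allowed = EQUAL_TYPES_OPSET11 if opset >= 11 else EQUAL_TYPES_OPSET7
    return onnx_type in allowed
```

With this table, equal_accepts("tensor(float)", 10) is False, which matches the "Type 'tensor(float)' ... of operator (Equal)" message, while equal_accepts("tensor(int32)", 10) is True.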

Fix it by casting the tensor to an integer type before the comparison, i.e. change

idx_in_level = torch.nonzero(levels == level).squeeze(1)

to

idx_in_level = torch.nonzero(levels.type(torch.int32) == level).squeeze(1)
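For a self-contained view of the same pattern, here is a minimal sketch assuming a float levels tensor like the one in the pooler; LevelSelector and num_levels are hypothetical names, not maskrcnn-benchmark code. It casts to int64 (the int32 cast above is equally accepted by Equal) and then collects the box indices for each level.

```python
import torch


class LevelSelector(torch.nn.Module):
    """Hypothetical module mimicking per-level box selection in a ROI pooler."""

    def __init__(self, num_levels: int):
        super().__init__()
        self.num_levels = num_levels

    def forward(self, levels: torch.Tensor):
        # Cast once so every Equal node compares integers, not floats
        int_levels = levels.to(torch.int64)
        indices = []
        for level in range(self.num_levels):
            # Indices of the boxes assigned to this pyramid level
            indices.append(torch.nonzero(int_levels == level).squeeze(1))
        return indices
```

Calling LevelSelector(2)(torch.tensor([0.0, 1.0, 0.0, 1.0])) returns the index tensors [0, 2] and [1, 3]. Casting once outside the loop means a traced export emits a single Cast node rather than one per level.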