schyun9212 / maskrcnn-benchmark

Converting maskrcnn-benchmark model to TorchScript or ONNX
MIT License

Failed to convert PyTorch model to ONNX #6

Open schyun9212 opened 4 years ago

schyun9212 commented 4 years ago

🐛 Bug

I exported the submodules backbone, rpn and roi_heads separately. The backbone was exported successfully and runs. The rpn, however, takes two inputs, 'image' and 'features'. Although the model was exported, nothing in the resulting graph is connected to the image input, which leads to wrong results.

Suspecting that the multiple inputs caused the bug, I created a model that chains the two submodules (backbone + rpn). But the graph of the exported model is the same as the backbone's.

The roi_heads model can be exported, but it fails to load in onnxruntime. Shape inference (creating the shape-inferred model) also fails, and the graph shows bbox as its only input.

/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/demo/unit_test.py:63: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  image_list = ImageList(image.unsqueeze(0), [(int(image.size(-2)), int(image.size(-1)))])
/home/jade/.pyenv/versions/maskrcnn-tracing-latest/lib/python3.7/site-packages/torch/onnx/symbolic_helper.py:198: UserWarning: You are trying to export the model with onnx:Resize for ONNX opset version 10. This operator might cause results to not match the expected results by PyTorch.
ONNX's Upsample/Resize operator did not match Pytorch's Interpolation until opset 11. Attributes to determine how to transform the input were added in onnx:Resize in opset 11 to support Pytorch's behavior (like coordinate_transformation_mode and nearest_mode).
We recommend using opset 11 and above for models using this operator. 
  "" + str(_export_onnx_opset_version) + ". "
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/demo/unit_test.py:91: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  image_list = ImageList(image.unsqueeze(0), [(int(image.size(-2)), int(image.size(-1)))])
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/maskrcnn_benchmark/structures/bounding_box.py:21: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  bbox = torch.as_tensor(bbox, dtype=torch.float32, device=device)
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/maskrcnn_benchmark/structures/bounding_box.py:26: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if bbox.size(-1) != 4:
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/maskrcnn_benchmark/modeling/rpn/inference.py:94: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  pre_nms_top_n = min(self.pre_nms_top_n, num_anchors)
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/maskrcnn_benchmark/modeling/box_coder.py:87: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/maskrcnn_benchmark/modeling/box_coder.py:89: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/maskrcnn_benchmark/modeling/box_coder.py:91: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w - 1
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/maskrcnn_benchmark/modeling/box_coder.py:93: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h - 1
/home/jade/.pyenv/versions/maskrcnn-tracing-latest/lib/python3.7/site-packages/torch/tensor.py:426: RuntimeWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  'incorrect results).', category=RuntimeWarning)
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/maskrcnn_benchmark/structures/bounding_box.py:216: TracerWarning: There are 4 live references to the data region being modified when tracing in-place operator clamp_. This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  self.bbox[:, 0].clamp_(min=0, max=self.size[0] - TO_REMOVE)
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/maskrcnn_benchmark/structures/bounding_box.py:217: TracerWarning: There are 4 live references to the data region being modified when tracing in-place operator clamp_. This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  self.bbox[:, 1].clamp_(min=0, max=self.size[1] - TO_REMOVE)
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/maskrcnn_benchmark/structures/bounding_box.py:218: TracerWarning: There are 4 live references to the data region being modified when tracing in-place operator clamp_. This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  self.bbox[:, 2].clamp_(min=0, max=self.size[0] - TO_REMOVE)
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/maskrcnn_benchmark/structures/bounding_box.py:219: TracerWarning: There are 4 live references to the data region being modified when tracing in-place operator clamp_. This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  self.bbox[:, 3].clamp_(min=0, max=self.size[1] - TO_REMOVE)
/home/jade/.pyenv/versions/maskrcnn-tracing-latest/lib/python3.7/site-packages/torch/onnx/symbolic_opset9.py:1881: UserWarning: Exporting aten::index operator of advanced indexing in opset 10 is achieved by combination of multiple ONNX operators, including Reshape, Transpose, Concat, and Gather. If indices include negative values, the exported graph will produce incorrect results.
  "If indices include negative values, the exported graph will produce incorrect results.")
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/demo/unit_test.py:123: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  image_list = ImageList(image.unsqueeze(0), [(int(image.size(-2)), int(image.size(-1)))])
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/maskrcnn_benchmark/modeling/rpn/inference.py:95: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  objectness, topk_idx = objectness.topk(pre_nms_top_n, dim=1, sorted=True)
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/maskrcnn_benchmark/modeling/rpn/inference.py:176: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  post_nms_top_n = min(self.fpn_post_nms_top_n, len(objectness))
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/maskrcnn_benchmark/modeling/poolers.py:84: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  for i, b in enumerate(boxes)
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/maskrcnn_benchmark/modeling/poolers.py:106: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  num_rois = len(rois)
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py:62: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  boxes_per_image = [len(box) for box in boxes]
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/maskrcnn_benchmark/structures/bounding_box.py:216: TracerWarning: There are 3 live references to the data region being modified when tracing in-place operator clamp_. This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  self.bbox[:, 0].clamp_(min=0, max=self.size[0] - TO_REMOVE)
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/maskrcnn_benchmark/structures/bounding_box.py:217: TracerWarning: There are 3 live references to the data region being modified when tracing in-place operator clamp_. This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  self.bbox[:, 1].clamp_(min=0, max=self.size[1] - TO_REMOVE)
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/maskrcnn_benchmark/structures/bounding_box.py:218: TracerWarning: There are 3 live references to the data region being modified when tracing in-place operator clamp_. This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  self.bbox[:, 2].clamp_(min=0, max=self.size[0] - TO_REMOVE)
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/maskrcnn_benchmark/structures/bounding_box.py:219: TracerWarning: There are 3 live references to the data region being modified when tracing in-place operator clamp_. This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  self.bbox[:, 3].clamp_(min=0, max=self.size[1] - TO_REMOVE)
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py:122: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  for j in range(1, num_classes):
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py:131: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  num_labels = len(boxlist_for_class)
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py:138: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  number_of_detections = len(result)
/home/jade/Workspace/maskrcnn/maskrcnn-benchmark/maskrcnn_benchmark/modeling/roi_heads/mask_head/inference.py:47: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  boxes_per_image = [len(box) for box in boxes]
Segmentation fault (core dumped)
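The `copy_` and `clamp_` TracerWarnings in the log above come from in-place writes into slices (`box_coder.py:87-93` and `bounding_box.py:216-219`). A common workaround when tracing is to compute each coordinate separately and combine them with `torch.stack`, so nothing is written into a view. The sketch below mirrors the formulas visible in the log (the `- 1` on the far corners and `TO_REMOVE`); the center/size computation, the `weights` tuple and the `bbox_xform_clip` default are my assumptions about `BoxCoder.decode`, not code copied from the repository, and only one box per row (no per-class `0::4` layout) is handled. `weights` must match the BoxCoder being replaced.

```python
import math

import torch


def decode_out_of_place(rel_codes, boxes, weights, bbox_xform_clip=math.log(1000.0 / 16)):
    """Apply (dx, dy, dw, dh) deltas to xyxy boxes, one box per row, without in-place writes."""
    TO_REMOVE = 1  # +1 pixel convention for widths/heights
    boxes = boxes.to(rel_codes.dtype)
    widths = boxes[:, 2] - boxes[:, 0] + TO_REMOVE
    heights = boxes[:, 3] - boxes[:, 1] + TO_REMOVE
    ctr_x = boxes[:, 0] + 0.5 * widths
    ctr_y = boxes[:, 1] + 0.5 * heights

    wx, wy, ww, wh = weights
    dx = rel_codes[:, 0] / wx
    dy = rel_codes[:, 1] / wy
    # Limit dw/dh so torch.exp cannot overflow on extreme predictions.
    dw = torch.clamp(rel_codes[:, 2] / ww, max=bbox_xform_clip)
    dh = torch.clamp(rel_codes[:, 3] / wh, max=bbox_xform_clip)

    pred_ctr_x = dx * widths + ctr_x
    pred_ctr_y = dy * heights + ctr_y
    pred_w = torch.exp(dw) * widths
    pred_h = torch.exp(dh) * heights

    # torch.stack builds a new tensor instead of assigning into pred_boxes[:, k::4].
    return torch.stack(
        [
            pred_ctr_x - 0.5 * pred_w,
            pred_ctr_y - 0.5 * pred_h,
            pred_ctr_x + 0.5 * pred_w - 1,
            pred_ctr_y + 0.5 * pred_h - 1,
        ],
        dim=1,
    )


def clip_out_of_place(boxes, image_size):
    """Clip xyxy boxes to a (width, height) image using clamp instead of clamp_."""
    TO_REMOVE = 1
    width, height = image_size
    x_max, y_max = width - TO_REMOVE, height - TO_REMOVE
    return torch.stack(
        [
            boxes[:, 0].clamp(min=0, max=x_max),
            boxes[:, 1].clamp(min=0, max=y_max),
            boxes[:, 2].clamp(min=0, max=x_max),
            boxes[:, 3].clamp(min=0, max=y_max),
        ],
        dim=1,
    )


# Zero deltas leave boxes unchanged; clipping pulls corners inside a 640x480 image.
boxes = torch.tensor([[10.0, 20.0, 49.0, 79.0]])
print(decode_out_of_place(torch.zeros(1, 4), boxes, (1.0, 1.0, 1.0, 1.0)))
print(clip_out_of_place(torch.tensor([[-5.0, -3.0, 700.0, 500.0]]), (640, 480)))
```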

To Reproduce

import numpy as np

import requests
import torch

from PIL import Image
from maskrcnn_benchmark.config import cfg
from predictor import COCODemo
from maskrcnn_benchmark.structures.image_list import ImageList
from maskrcnn_benchmark.structures.bounding_box import BoxList

from demo.utils import imshow, masking_image, load_image
from transform import transform_image
import onnx
import onnxruntime as ort

from demo.onnx.utils import infer_shapes

#%%
ONNX_OPSET_VERSION = 10

config_file = "../configs/caffe2/e2e_mask_rcnn_R_50_FPN_1x_caffe2.yaml"
cfg.merge_from_file(config_file)
cfg.merge_from_list(["MODEL.DEVICE", "cpu"])
cfg.freeze()

coco_demo = COCODemo(
    cfg,
    confidence_threshold=0.7,
    min_image_size=800,
)

for param in coco_demo.model.parameters():
    param.requires_grad = False

original_image = load_image("./sample.jpg")
image, t_width, t_height = transform_image(cfg, original_image)

height, width = original_image.shape[:-1]

# Gradient must be deactivated
for p in coco_demo.model.parameters():
    p.requires_grad_(False)

BACKBONE_PATH = "backbone.onnx"

class Backbone(torch.nn.Module):
    def __init__(self):
        super(Backbone, self).__init__()

    def forward(self, image):
        image_list = ImageList(image.unsqueeze(0), [(int(image.size(-2)), int(image.size(-1)))])

        result = coco_demo.model.backbone(image_list.tensors)
        return result

backbone = Backbone()
backbone.eval()
expected_backbone_result = backbone(image)

torch.onnx.export(backbone, (image, ), BACKBONE_PATH,
                    verbose=False,
                    do_constant_folding=True,
                    opset_version=ONNX_OPSET_VERSION, input_names=["i_image"])

infer_shapes(BACKBONE_PATH, "backbone.shape.onnx")

ort_session = ort.InferenceSession(BACKBONE_PATH)
backbone_result = ort_session.run(None, {ort_session.get_inputs()[0].name: image.numpy()})
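# The FPN backbone returns one feature map per pyramid level; keep all of them, not only the first
features = tuple(torch.from_numpy(np.asarray(out)) for out in backbone_result)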

RPN_PATH = "rpn.onnx"

class RPN(torch.nn.Module):
    def __init__(self):
        super(RPN, self).__init__()

    def forward(self, image, features):
        image_list = ImageList(image.unsqueeze(0), [(int(image.size(-2)), int(image.size(-1)))])
        result = coco_demo.model.rpn(image_list, features)[0][0]
        # rpn has extra field "objectness"
        result = (result.bbox,) + tuple(f for f in (result.get_field(field) for field in sorted(result.fields())) if isinstance(f, torch.Tensor))
        return result

rpn = RPN()
rpn.eval()
expected_rpn_result = rpn(image, features)

torch.onnx.export(rpn, (image, features), RPN_PATH,
                    verbose=False,
                    do_constant_folding=True,
                    opset_version=ONNX_OPSET_VERSION, input_names=["image", "features"])

infer_shapes(RPN_PATH, "rpn.shape.onnx")

BACKBONE_RPN_PATH = "backbone+rpn.onnx"

class BackboneRPN(torch.nn.Module):
    def __init__(self):
        super(BackboneRPN, self).__init__()

    def forward(self, image):
        image_list = ImageList(image.unsqueeze(0), [(int(image.size(-2)), int(image.size(-1)))])

        features = coco_demo.model.backbone(image_list.tensors)
        result = coco_demo.model.rpn(image_list, features)[0][0]
        # rpn has extra field "objectness"
        result = (result.bbox,) + tuple(f for f in (result.get_field(field) for field in sorted(result.fields())) if isinstance(f, torch.Tensor))
        return result

backbone_rpn = BackboneRPN()
backbone_rpn.eval()
expected_backbone_rpn_result = backbone_rpn(image)

torch.onnx.export(backbone_rpn, (image, ), BACKBONE_RPN_PATH,
                    verbose=False,
                    do_constant_folding=True,
                    opset_version=ONNX_OPSET_VERSION, input_names=["image"])

infer_shapes(BACKBONE_RPN_PATH, "backbone+rpn.shape.onnx")
ort_session = ort.InferenceSession(BACKBONE_RPN_PATH)
backbone_rpn_result = ort_session.run(None, {ort_session.get_inputs()[0].name: image.numpy()})

ROI_PATH = "roi.onnx"

class ROI(torch.nn.Module):
    def __init__(self):
        super(ROI, self).__init__()

    def forward(self, features, proposals):
        bbox, objectness = proposals

        proposals = BoxList(bbox, (t_width, t_height), mode="xyxy")
        proposals.add_field("objectness", objectness)

        _, result, _ = coco_demo.model.roi_heads(features, [proposals])

        result = (result[0].bbox,
                result[0].get_field("labels"),
                result[0].get_field("mask"),
                result[0].get_field("scores"))

        return result

roi = ROI()
roi.eval()
expected_roi_result = roi(expected_backbone_result, expected_rpn_result)

torch.onnx.export(roi, (expected_backbone_result, expected_backbone_rpn_result), ROI_PATH,
                    verbose=False,
                    do_constant_folding=True,
                    opset_version=ONNX_OPSET_VERSION, input_names=["image", "proposals"])

infer_shapes(ROI_PATH, "roi.shape.onnx")
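The script computes `expected_*` results in PyTorch but never compares them with the onnxruntime outputs, so a wrong result only shows up by inspection. A small comparison helper makes mismatches explicit. `compare_outputs` is a hypothetical name, and its default tolerances are arbitrary choices, not values taken from the issue.

```python
import numpy as np


def compare_outputs(expected, actual, rtol=1e-3, atol=1e-4):
    """Compare two equal-length sequences of array-likes.

    Returns (index, max_abs_diff) for every output that differs; a shape
    mismatch is reported as inf. An empty list means all outputs matched.
    """
    if len(expected) != len(actual):
        raise ValueError(f"output count differs: {len(expected)} vs {len(actual)}")
    mismatches = []
    for idx, (exp, act) in enumerate(zip(expected, actual)):
        exp = np.asarray(exp, dtype=np.float64)
        act = np.asarray(act, dtype=np.float64)
        if exp.shape != act.shape:
            mismatches.append((idx, float("inf")))
        elif not np.allclose(exp, act, rtol=rtol, atol=atol):
            mismatches.append((idx, float(np.max(np.abs(exp - act)))))
    return mismatches
```

In the repro this could be called as `compare_outputs(expected_backbone_rpn_result, backbone_rpn_result)`; `np.asarray` accepts both CPU torch tensors that do not require grad and the numpy arrays returned by `InferenceSession.run`.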

Environment

PyTorch version: 1.3.1
Is debug build: No
CUDA used to build PyTorch: 10.1.243

OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: version 3.10.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.1.243
GPU models and configuration: GPU 0: GeForce RTX 2080 Ti
Nvidia driver version: 440.44
cuDNN version: Probably one of the following:
/usr/local/cuda-10.0/targets/x86_64-linux/lib/libcudnn.so.7
/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.7.6.5

Versions of relevant libraries:
[pip3] numpy==1.18.1
[pip3] onnx==1.6.0
[pip3] onnxruntime==1.1.0
[pip3] onnxruntime-gpu==1.1.0
[pip3] torch==1.3.1
[pip3] torchvision==0.4.2
[conda] Could not collect

schyun9212 commented 4 years ago

I was using opset 10. To avoid the incorrect Resize behavior, I need to change the opset version to 11:

UserWarning: You are trying to export the model with onnx:Resize for ONNX opset version 10. This operator might cause results to not match the expected results by PyTorch.
ONNX's Upsample/Resize operator did not match Pytorch's Interpolation until opset 11. Attributes to determine how to transform the input were added in onnx:Resize in opset 11 to support Pytorch's behavior (like coordinate_transformation_mode and nearest_mode).

I removed the unnecessary int() casts to get rid of this warning:

TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
# from
image_list = ImageList(image.unsqueeze(0), [(int(image.size(-2)), int(image.size(-1)))])
# to
image_list = ImageList(image.unsqueeze(0), [(image.size(-2), image.size(-1))])

But this produces wrong results.
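The warning is the key point: under tracing, `int(image.size(-2))` is frozen into the graph, so an export recorded at one image size silently keeps that size for every other input. One way to catch such frozen values is to trace at one input size and compare against eager execution at another. `traced_matches_eager` and the toy `half_width_int` are hypothetical names; the check uses `torch.jit.trace`, which the TorchScript-based `torch.onnx.export` in PyTorch 1.3 relies on for `nn.Module` export, rather than inspecting the ONNX file itself.

```python
import warnings

import torch


def traced_matches_eager(fn, trace_input, check_inputs):
    """Trace fn on trace_input, then compare traced vs eager fn on each check input.

    Returns one bool per check input; False means the traced function gave a
    different shape or value, which suggests something was frozen into the trace.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # TracerWarnings are expected here
        traced = torch.jit.trace(fn, trace_input, check_trace=False)
    results = []
    for x in check_inputs:
        eager_out = fn(x)
        traced_out = traced(x)
        results.append(
            eager_out.shape == traced_out.shape and torch.allclose(eager_out, traced_out)
        )
    return results


def half_width_int(x):
    w = int(x.size(-1))  # Python int: recorded as a constant during tracing
    return x[..., : w // 2]


# Traced at width 4, the slice end stays 2 even for a width-8 input.
print(traced_matches_eager(half_width_int, torch.ones(1, 4), [torch.ones(1, 4), torch.ones(1, 8)]))
```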