SwinTransformer / Swin-Transformer-Object-Detection

This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" on Object Detection and Instance Segmentation.
https://arxiv.org/abs/2103.14030
Apache License 2.0

Onnx Conversion error: ONNX Expand input shape constraint not satisfied. #112

Open montensorrt opened 2 years ago

montensorrt commented 2 years ago

Thanks for your code, we appreciate it a lot. When I run pytorch2onnx.py, I get the following error:

    [W ..\torch\csrc\jit\passes\onnx\shape_type_inference.cpp:419] Warning: Constant folding in symbolic shape inference fails: Dimension out of range (expected to be in range of [-1, 0], but got -7)
    Exception raised from maybe_wrap_dim at ..\c10/core/WrapDimMinimal.h:33 (most recent call first):
    00007FFB64F710D2 00007FFB64F71070 c10.dll!c10::Error::Error [ @ ]
      torch._C._jit_pass_onnx_node_shape_type_inference(n, _params_dict, opset_version)
    RuntimeError: input_shape_value == reshape_value || input_shape_value == 1 || reshape_value == 1 INTERNAL ASSERT FAILED at "..\torch\csrc\jit\passes\onnx\shape_type_inference.cpp":520, please report a bug to PyTorch. ONNX Expand input shape constraint not satisfied.

What should I do?

What I ran:

python tools/deployment/pytorch2onnx.py configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco.py checkpoint/moby_mask_rcnn_swin_tiny_patch4_window7_3x.pth --test-img demo/test.jpg --verify

My environment: Python 3.7, PyTorch 1.9, CUDA 11.1, mmcv 1.3.9.

I am converting my own saved model, which I trained and saved myself.

I found someone dealing with a similar problem at https://giters.com/lucastabelini/LaneATT/issues/77 — it looks like it is related to NMS?

Looking forward to your reply!

montensorrt commented 2 years ago

ONNX does not support torch.roll, so I replaced it. For example, the original:

    x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))

becomes:

        # equivalent to the roll above: a positive shift moves elements toward
        # higher indices, so the last shift_size rows/columns wrap to the front
        shifted_x = torch.cat((shifted_x[:, -self.shift_size:, :, :], shifted_x[:, :-self.shift_size, :, :]), dim=1)
        shifted_x = torch.cat((shifted_x[:, :, -self.shift_size:, :], shifted_x[:, :, :-self.shift_size, :]), dim=2)
        x = shifted_x

Note the slice order: taking the leading slice first (x[shift:] before x[:shift]) would implement a roll with negative shifts instead, which matches the forward shift in the Swin block but not the reverse roll quoted above.

And I'm sure the channel dimension does not change.
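A quick equality check for this kind of replacement (a minimal sketch; roll_free, the tensor shape, and the shift value here are made up for illustration):

```python
import torch


def roll_free(x, shift_size):
    # Slicing + torch.cat replacement for
    # torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2));
    # both ops export to ONNX, unlike Roll.
    x = torch.cat((x[:, -shift_size:, :, :], x[:, :-shift_size, :, :]), dim=1)
    x = torch.cat((x[:, :, -shift_size:, :], x[:, :, :-shift_size, :]), dim=2)
    return x


# Dummy (B, H, W, C) tensor, the layout used inside the Swin block.
x = torch.arange(2 * 8 * 8 * 3, dtype=torch.float32).reshape(2, 8, 8, 3)
print(torch.equal(roll_free(x, 3), torch.roll(x, shifts=(3, 3), dims=(1, 2))))
```

Running the same check against the traced/exported graph (e.g. with onnxruntime on a fixed input) is a cheap way to confirm the replacement survives conversion.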