open-mmlab / mmsegmentation

OpenMMLab Semantic Segmentation Toolbox and Benchmark.
https://mmsegmentation.readthedocs.io/en/main/
Apache License 2.0
8.23k stars 2.61k forks

Export ONNX model error #1870

Closed Aspirinkb closed 2 years ago

Aspirinkb commented 2 years ago

Thanks for your error report and we appreciate it a lot.

Checklist

  1. I have searched related issues but cannot get the expected help.
  2. The bug has not been fixed in the latest version.

Describe the bug Cannot convert a trained model into ONNX format when using tools/pytorch2onnx.py. When torch.onnx.export runs, there is an error: TypeError: forward() got multiple values for argument 'img_metas'. I read the code: the model's forward function is wrapped and img_metas is set there. I cannot figure out why this error arises when img_metas is only set once.
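
(For context, tools/pytorch2onnx.py binds img_metas onto the model's forward exactly once before export. The snippet below is an illustrative sketch of that wrapping, not the actual mmseg source; DummySegmentor and its arguments are made up.)

from functools import partial


class DummySegmentor:
    # stand-in for an mmseg segmentor; only the forward signature matters here
    def forward(self, img, img_metas, return_loss=True):
        return img


model = DummySegmentor()
# bind img_metas and return_loss once, the same way the export script wraps forward
model.forward = partial(model.forward, img_metas=[[{}]], return_loss=False)
out = model.forward("img")  # img_metas is supplied only via the partial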

Reproduction

  1. What command or script did you run?
python tools/pytorch2onnx.py \
./configs/convnext_spot/upernet_convnext_base_fp16_960x960_10k_spot2.py \
--checkpoint /general-user/frank/spot/data/spot/upernet_convnext_base_fp16_960x960_10k_spot2/best_mDice_iter_8100.pth \
--output-file /general-user/frank/spot/mmsegmentation/work_dirs/upernet_convnext_base_fp16_960x960_10k_spot2/model.onnx \
--input-img /general-user/frank/spot/data/spot/diff_images2/val/IMG_20220707_162730_1.jpg \
--show \
--verify \
--dynamic-export \
--cfg-options \
  model.test_cfg.mode="whole"
  2. Did you make any modifications on the code or config? Did you understand what you have modified? No

  3. What dataset did you use? My own

Environment

  1. Please run python mmseg/utils/collect_env.py to collect necessary environment information and paste it here.
  2. You may add additional information that may be helpful for locating the problem, such as
    • How you installed PyTorch [e.g., pip, conda, source]
    • Other environment variables that may be related (such as $PATH, $LD_LIBRARY_PATH, $PYTHONPATH, etc.)

Error traceback

If applicable, paste the error traceback here.

Traceback (most recent call last):
  File "tools/pytorch2onnx.py", line 387, in <module>
    pytorch2onnx(
  File "tools/pytorch2onnx.py", line 196, in pytorch2onnx
    torch.onnx.export(
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/__init__.py", line 350, in export
    return utils.export(
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/utils.py", line 163, in export
    _export(
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/utils.py", line 1074, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/utils.py", line 727, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/utils.py", line 602, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/utils.py", line 517, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/jit/_trace.py", line 1175, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1118, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/mmcv/runner/fp16_utils.py", line 118, in new_func
    return old_func(*args, **kwargs)
TypeError: forward() got multiple values for argument 'img_metas'

Bug fix

If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!

Aspirinkb commented 2 years ago

torch: 1.12.0+cu116

feiqiu-cyber commented 2 years ago

torch: 1.12.0+cu116

Hello, have you solved the problem? I get the same problem when I run pytorch2onnx.py.

Aspirinkb commented 2 years ago

torch: 1.12.0+cu116

Hello, have you solved the problem? I get the same problem when I run pytorch2onnx.py.

No. I guess this is a problem with MMSeg's support for ONNX export.

xiexinch commented 2 years ago

Hi @Aspirinkb, I can run the following command successfully with torch 1.9 on the CPU.

python tools/pytorch2onnx.py configs/convnext/upernet_convnext_base_fp16_512x512_160k_ade20k.py --checkpoint checkpoints/upernet_convnext_base_fp16_512x512_160k_ade20k_20220227_181227-02a24fc6.pth --input-img demo/demo.png --show --verify  --cfg-options model.test_cfg.mode='whole'

Could you try removing --dynamic-export in your command? And I'll test it with torch 1.12.

xiexinch commented 2 years ago

Same error with torch 1.12. Hi @RunningLeon and @zhouzaida, could you take a look at this issue if you're available?

Aspirinkb commented 2 years ago

I changed the signature and body of the forward function as follows:

    @auto_fp16(apply_to=('img', ))
    def forward(self, img, **kwargs):  # original signature: (self, img, img_metas, return_loss=True, **kwargs)
        """Calls either :func:`forward_train` or :func:`forward_test` depending
        on whether ``return_loss`` is ``True``.

        Note this setting will change the expected inputs. When
        ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
        and List[dict]), and when ``return_loss=False``, img and img_meta
        should be double nested (i.e. List[Tensor], List[List[dict]]), with
        the outer list indicating test time augmentations.
        """
        # img_metas and return_loss must now arrive as keyword arguments,
        # e.g. bound once via functools.partial in tools/pytorch2onnx.py.
        try:
            return_loss = kwargs.pop("return_loss")
            img_metas = kwargs.pop("img_metas")
        except KeyError as e:
            raise Exception(f"Missing param return_loss or img_metas: {e}")
        if return_loss:
            return self.forward_train(img, img_metas, **kwargs)
        else:
            return self.forward_test(img, img_metas, **kwargs)

and exported the ONNX model. Running it with onnxruntime is OK.
But please note that one must change it back before training models!

This is not the correct way to fix the ONNX export problem, but I could not just do nothing...
Waiting for an official bug fix...
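
(For anyone who wants to sanity-check the exported model the same way, a minimal onnxruntime sketch could look like the following; the model path and input size are assumptions, not values from this thread.)

import numpy as np
import onnxruntime as ort

# load the exported model on CPU; the path is an assumption
sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name

# dummy NCHW float input; 960x960 matches the config name above but is an assumption
dummy = np.random.randn(1, 3, 960, 960).astype(np.float32)
outputs = sess.run(None, {input_name: dummy})
print([o.shape for o in outputs])  # inspect the output shape(s) of the exported graph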

AndPuQing commented 2 years ago

Is there any progress on this issue? I have the same problem.

RunningLeon commented 2 years ago

Something goes wrong with the combination of auto_fp16 and the functools.partial wrapping of the model's forward. cc @zhouzaida. You could comment out the @auto_fp16 decorator while using pytorch2onnx.py for now: https://github.com/open-mmlab/mmsegmentation/blob/dd42fa8d0125632371a41a87c20485494c973535/mmseg/models/segmentors/base.py#L96

https://github.com/open-mmlab/mmsegmentation/blob/dd42fa8d0125632371a41a87c20485494c973535/tools/pytorch2onnx.py#L170
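
(A rough sketch of that temporary workaround, based on the forward signature quoted earlier in this thread; it is a local, export-only patch, and the decorator must be restored afterwards.)

    # mmseg/models/segmentors/base.py, BaseSegmentor.forward, export-only patch:
    # temporarily disable the fp16 wrapper and keep the original dispatch logic.
    # @auto_fp16(apply_to=('img', ))
    def forward(self, img, img_metas, return_loss=True, **kwargs):
        if return_loss:
            return self.forward_train(img, img_metas, **kwargs)
        else:
            return self.forward_test(img, img_metas, **kwargs)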

RunningLeon commented 2 years ago

@Aspirinkb @AndPuQing Hi, could you guys try mmdeploy? The deployment feature in mmseg will be removed in the future.
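
(For reference, a typical mmdeploy export command looks roughly like the following; the deploy config name and the placeholder paths are assumptions and may differ across mmdeploy versions, so treat this as a sketch.)

python tools/deploy.py \
    configs/mmseg/segmentation_onnxruntime_dynamic.py \
    ${MMSEG_CONFIG} \
    ${CHECKPOINT} \
    ${TEST_IMAGE} \
    --work-dir work_dirs/onnx_export \
    --device cpu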

AndPuQing commented 2 years ago

I tried mmdeploy and it worked for me.