SHI-Labs / OneFormer

OneFormer: One Transformer to Rule Universal Image Segmentation, arxiv 2022 / CVPR 2023
https://praeclarumjj3.github.io/oneformer
MIT License

Convert Model to TensorRT #16

Closed abhigoku10 closed 1 year ago

abhigoku10 commented 1 year ago

@honghuis @SkalskiP thanks for sharing the source code. I just wanted to know: can we convert this model to TensorRT or ONNX format? If so, please share the conversion and inference script.

Thanks in advance

SkalskiP commented 1 year ago

I'm not sure if the team has any ready solution you could use, but if not, I have a backup suggestion for you. OneFormer is built with Detectron2. I deployed a similar segmentation model in the past and used this script for the conversion, and it worked.

praeclarumjj3 commented 1 year ago

Hi @abhigoku10, thanks for your interest in our work. We don't have an already written script to convert the model to TensorRT or ONNX format.

I think the conversion should be pretty straightforward. You can try following the script shared by @SkalskiP (thanks for sharing!). If you need any assistance from our side, please let us know.

praeclarumjj3 commented 1 year ago

Feel free to re-open the issue if you need any help.

tomhog commented 1 year ago

I've managed to export the model to ONNX. There are a few issues at the moment to get it working:

  1. I wasn't sure how to get MSDeformAttnFunction to export (I think it's possible), so I commented it out and used the ms_deform_attn_core_pytorch fallback.
  2. ONNX expects just "image" in the input dictionary and there was no way to pass "task", so I check whether it is missing and add one to the dict so it doesn't crash.
  3. ONNX expects only a list of tensors as output, so since I was only interested in semantic segmentation I just output that (see the sketch below).
  4. There's an op used (I can't remember which) that PyTorch 1.10's ONNX export didn't support. You need at least opset 16, so I had to switch to PyTorch 1.12.1 and torchvision 0.13.1, then build detectron2 from source as I couldn't find a compatible prebuilt package.

With those changes (plus a few small tweaks), the export_model.py script worked and I was able to use the model in onnxruntime.
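
As a rough illustration of workarounds (2) and (3), the patched forward pass boils down to something like the snippet below. This is a hypothetical sketch, not the exact diff from the commit; the "The task is semantic" string and the "sem_seg" output key are assumptions based on OneFormer's demo and detectron2 conventions, so check them against the actual model code.

# Hypothetical sketch of workarounds (2) and (3): default the missing "task"
# entry and expose only the semantic output, since tracing feeds in a bare
# image tensor and ONNX wants a flat list/tuple of tensors back.
def traced_forward(self, batched_inputs):
    for inputs in batched_inputs:
        if "task" not in inputs:                      # tracing only supplies "image"
            inputs["task"] = "The task is semantic"   # fixed task token for export
    results = self.original_forward(batched_inputs)   # placeholder for the real forward
    return [r["sem_seg"] for r in results]            # keep only the semantic tensor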

fernandorovai commented 1 year ago

Hi @tomhog, could you please share the changes you did for exporting to onnx? Thanks a ton!

tomhog commented 1 year ago

Hi @fernandorovai

Certainly, though as I said it's a bit of a mess at the moment. I will tidy it up; I just needed to prove it can work before moving forward. Also worth noting: I'm not sure whether the ONNX file will work on GPU, I haven't tried yet.

The changes are:

  1. A hack to ms_deform_attn so it doesn't use the custom op (and why it might not work on GPU)
  2. Changes to the OneFormer model to handle a missing task in the input dict, and to only output one task type
  3. A script to export the ONNX file
  4. A demo using the ONNX file

The export command would look something like this:

python ./demo/export_model.py --config-file ./configs/ade20k/convnext/oneformer_convnext_large_bs16_160k_2gpu.yaml \
   --format onnx \
   --export-method tracing \
   --sample-image ../datasets/512x512.jpg \
   --run-eval \
   --output ./output \
   MODEL.IS_TRAIN False MODEL.IS_DEMO True MODEL.WEIGHTS ./output/ade20k_convnext_large/model_0099999.pth

Currently the dimensions of the sample image need to match the dimensions the model was trained at.
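
For reference, something like the following should run the exported file with onnxruntime. It's a minimal sketch: the "input"/"output" tensor names follow the torch.onnx.export call quoted later in this thread, the 512x512 size matches the sample image above, and the preprocessing (BGR, CHW, float32, no batch dimension) and the output path are assumptions rather than verified demo code.

# Minimal sketch: run the exported OneFormer ONNX model with onnxruntime.
# Assumes a fixed 512x512 input, a single CHW "input" tensor and a single
# semantic "output"; preprocessing here is illustrative only.
import cv2
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("./output/model.onnx",        # path assumed; check the export output dir
                            providers=["CPUExecutionProvider"])

img = cv2.imread("../datasets/512x512.jpg")               # BGR, HxWx3
img = cv2.resize(img, (512, 512)).astype(np.float32)      # must match the export size
chw = np.transpose(img, (2, 0, 1))                        # CHW, as traced from detectron2

(sem_seg,) = sess.run(["output"], {"input": chw})         # assumed (num_classes, H, W) logits
pred = sem_seg.argmax(axis=0)                             # per-pixel class ids
print(pred.shape)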

I think the rest of that commit is just noise (I was having odd issues with pybind11) plus a few of my own bash scripts for convenience.

If it works for you please let me know and maybe we could work on a cleaner approach to add official support to OneFormer.

Tom

PS: Remember from my post above, you have to use PyTorch 1.12.1 and torchvision 0.13.1.

John0x commented 1 year ago

@tomhog have you tested GPU support yet? I was hoping to use OneFormer as an ONNX model, but not having GPU support would mean that I have to use a different model or implement my own conversion.

AAAstorga commented 1 year ago

I'm very new to ML, so I apologize if my question is naive. Is it possible to convert this model to PyTorch Lite? I'm curious to see if it's possible to use this model with https://playtorch.dev/.

I was hoping to follow this: https://playtorch.dev/docs/tutorials/prepare-custom-model/

But I don't think it's that straightforward. I would appreciate any guidance if possible! Thank you.

praeclarumjj3 commented 1 year ago

Hi @AAAstorga, thanks for your interest in our work.

You should follow the tutorial on using DETR with PlayTorch, as OneFormer and DETR are both built using detectron2. https://playtorch.dev/docs/tutorials/snacks/object-detection/.

Also, it might be better to create a new issue for PlayTorch.

AAAstorga commented 1 year ago

Thank you @praeclarumjj3 - I appreciate the response. Do you have any directions on how to load the pretrained model in Python with PyTorch so I can convert it to a mobile friendly version? This might sound like a simple question, but I'm just starting to learn about all of this. Thank you!

praeclarumjj3 commented 1 year ago

You can refer to the PyTorch tutorials: https://pytorch.org/tutorials/
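
In case it helps as a starting point, the generic PyTorch to PyTorch Lite flow is to trace (or script) the model and then run the mobile optimizer. Below is a minimal sketch where model and example_input are placeholders, not OneFormer-specific code; getting OneFormer itself to trace cleanly is the hard part, as the ONNX discussion above shows.

# Minimal sketch of the generic PyTorch -> PyTorch Lite flow; "model" and
# "example_input" are placeholders, not OneFormer-specific code.
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

model.eval()
traced = torch.jit.trace(model, example_input)            # or torch.jit.script(model)
optimized = optimize_for_mobile(traced)                   # mobile-oriented graph passes
optimized._save_for_lite_interpreter("oneformer.ptl")     # .ptl file for PlayTorch / Lite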

SHUNLU-1 commented 1 year ago

@tomhog I'm using your code to convert to ONNX and I'm running into this problem. Is there any way to fix it? Looking forward to your reply!

Traceback (most recent call last):
  File "./demo/export_model.py", line 249, in <module>
    exported_model = export_tracing(torch_model, sample_inputs)
  File "./demo/export_model.py", line 154, in export_tracing
    torch.onnx.export(traceable_model, (image,), f, verbose=True, opset_version=16, do_constant_folding=False, input_names=["input"], output_names=["output"], dynamic_axes={})  # STABLE_ONNX_OPSET_VERSION
  File "/home/lbm/lbm_src/conda_env/env/oneformer/lib/python3.8/site-packages/torch/onnx/__init__.py", line 350, in export
    return utils.export(
  File "/home/lbm/lbm_src/conda_env/env/oneformer/lib/python3.8/site-packages/torch/onnx/utils.py", line 163, in export
    _export(
  File "/home/lbm/lbm_src/conda_env/env/oneformer/lib/python3.8/site-packages/torch/onnx/utils.py", line 1110, in _export
    ) = graph._export_onnx(  # type: ignore[attr-defined]
RuntimeError: ONNX export failed: Couldn't export Python operator NATTEN2DQKRPBFunction

lianglilong-gloritytech commented 6 months ago

(Quoting @tomhog's export instructions from the earlier comment.)

@tomhog Hi, sorry to bother you. Have you ever encountered a significant drop in accuracy when running inference with the ONNX model? I compared the outputs for the same input; the difference comes from MSDeformAttnFunction.apply vs. ms_deform_attn_core_pytorch. Picture 1 below compares the ONNX output (CUDA) against the original model; Picture 2 compares the ONNX output (CPU) against the original model.

The accuracy is slightly worse on CPU, but it drops straight to 0 on CUDA. Have you encountered this problem, or can you offer any suggestions? (screenshots attached)
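
For anyone debugging this kind of mismatch, a quick way to quantify it is to compare the raw semantic logits from the two paths on the same preprocessed input. Below is a minimal sketch; torch_sem_seg and onnx_sem_seg are placeholders for the original model's output and the ONNX session's output, assumed to be numpy arrays of the same shape.

# Minimal sketch: quantify the PyTorch vs. ONNX discrepancy on one input.
# torch_sem_seg / onnx_sem_seg are placeholder (num_classes, H, W) arrays.
import numpy as np

diff = np.abs(torch_sem_seg - onnx_sem_seg)
print("max abs diff    :", diff.max())
print("mean abs diff   :", diff.mean())
print("argmax agreement:",
      (torch_sem_seg.argmax(0) == onnx_sem_seg.argmax(0)).mean())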

roboserg commented 3 months ago

Was anyone able to run the ONNX model on the GPU and get the same metrics as with PyTorch inference (e.g. AP50)? @tomhog

xuyuxiu83 commented 3 months ago

(Quoting @tomhog's export instructions from the earlier comment.)

Thank you very much for providing the ONNX conversion code! I have successfully obtained the ONNX file and run inference with it. Have you tried converting the model to TensorRT format? I ran into an issue where the BiasGelu operator is not supported during that conversion. How can I resolve this? (screenshot attached)
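
For what it's worth, BiasGelu is an ONNX Runtime contrib op (com.microsoft domain) rather than a standard ONNX op, which is why the TensorRT parser rejects it; it typically appears when the graph was saved after ONNX Runtime's optimizer fused Add + Gelu, not when the file comes straight from torch.onnx.export. A minimal sketch for checking whether the file contains contrib-domain nodes (the filename is a placeholder):

# Minimal sketch: list non-standard (contrib-domain) ops in an ONNX file.
# If BiasGelu shows up here, the graph was saved after ONNX Runtime fused it;
# re-exporting straight from PyTorch usually avoids contrib ops.
import onnx

model = onnx.load("oneformer.onnx")  # placeholder path
contrib_ops = {node.op_type for node in model.graph.node if node.domain not in ("", "ai.onnx")}
print(contrib_ops or "no contrib-domain ops found")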