ZHKKKe / MODNet

A Trimap-Free Portrait Matting Solution in Real Time [AAAI 2022]
Apache License 2.0

Exporting model to onnx format #10

Closed anilsathyan7 closed 3 years ago

anilsathyan7 commented 3 years ago

How to export the model to onnx format? I tried the following with the Colab demo code, but it raised an error: TypeError: forward() missing 1 required positional argument: 'inference'

import torch
import torch.nn as nn
from src.models.modnet import MODNet

model = MODNet(backbone_pretrained=False)
model = nn.DataParallel(model).cuda()

state_dict = torch.load(pretrained_ckpt)  # path to the pretrained checkpoint
model.load_state_dict(state_dict)
model.eval()

dummy_input = torch.randn(1, 3, 512, 512)
# this line raises: TypeError: forward() missing 1 required positional argument: 'inference'
torch.onnx.export(model.module, dummy_input, '/content/MODNet/modnet.onnx', export_params=True)

BTW, the test results look amazing!!!

ZHKKKe commented 3 years ago

Hi, thanks for your attention!

Please refer to the forward function of MODNet defined in the file MODNet/src/models/modnet.py. This function takes an argument inference as input. Setting inference = True disables the outputs of LRBranch and HRBranch, i.e., s_p and d_p in the paper, to save inference time.

I am not familiar with ONNX, but it seems that ONNX does not support conditional statements in the forward function. A possible solution may be to rewrite the forward function, deleting the code related to the argument inference.
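Something like the following sketch might work (untested; it assumes forward returns (pred_semantic, pred_detail, pred_matte) as in src/models/modnet.py):

import torch.nn as nn

class MODNetForExport(nn.Module):
    # hypothetical wrapper, not part of this repo
    def __init__(self, modnet):
        super().__init__()
        self.modnet = modnet

    def forward(self, im):
        # fix inference=True so the traced graph has no conditional;
        # keep only the matte, since the other two outputs are
        # disabled (None) in inference mode
        _, _, matte = self.modnet(im, True)
        return matte

Then you would export MODNetForExport(model.module) instead of model.module.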

anilsathyan7 commented 3 years ago

Thanks, it worked !!!

Tomas1337 commented 3 years ago

Hi, would you mind sharing your work on the onnx model? I'm trying to do the same, but then export it to an OpenVINO model afterwards.

I'm pretty new to this, so I'm finding my way through. It looks like you exported the entire MODNet to ONNX. Wouldn't you have to export each of the three models to ONNX (LR, HR, FR)?

Tomas1337 commented 3 years ago

Also: "Setting inference = True disables the outputs of LRBranch and HRBranch."

Do you not need the LR and HR branches when doing inference?

manthan2305 commented 3 years ago

@ZHKKKe First of all, thank you for such amazing work. I really appreciate it.

As you suggested above, I was able to generate a .onnx file by modifying your code, but got these warnings:

/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:3103: UserWarning: The default behavior for interpolate/upsample with float scale_factor changed in 1.6.0 to align with other frameworks/libraries, and now uses scale_factor directly, instead of relying on the computed output size. If you wish to restore the old behavior, please set recompute_scale_factor=True. See the documentation of nn.Upsample for details.

/usr/local/lib/python3.6/dist-packages/torch/onnx/symbolic_helper.py:267: UserWarning: You are trying to export the model with onnx:Upsample for ONNX opset version 9. This operator might cause results to not match the expected results by PyTorch. ONNX's Upsample/Resize operator did not match Pytorch's Interpolation until opset 11. Attributes to determine how to transform the input were added in onnx:Resize in opset 11 to support Pytorch's behavior (like coordinate_transformation_mode and nearest_mode). We recommend using opset 11 and above for models using this operator.

I also tried to run inference with this model by following this tutorial.

But I couldn't start an InferenceSession. It raised the error below:

RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Exception during initialization: /onnxruntime_src/onnxruntime/core/providers/cpu/tensor/upsample.h:277 void onnxruntime::UpsampleBase::ScalesValidation(const std::vector<float>&, onnxruntime::UpsampleMode) const scale >= 1 was false. Scale value should be greater than or equal to 1.

I couldn't figure it out. Can you help me or suggest anything I can do about it?

My work so far is in the Colab notebook mentioned below.

Here is the link to the onnx model.
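For reference, here is roughly how I create the session (the model file name is assumed):

import numpy as np
import onnxruntime as ort

# session creation is where the exception above is thrown
session = ort.InferenceSession('modnet.onnx')
input_name = session.get_inputs()[0].name

# any 1x3x512x512 float32 input works as a smoke test
dummy = np.random.randn(1, 3, 512, 512).astype(np.float32)
matte = session.run(None, {input_name: dummy})[0]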

Tomas1337 commented 3 years ago

I've gotten it to work with ONNX. Have you tried setting the opset parameter to 11? That worked for me.
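Something like this (file paths assumed; model is the loaded MODNet from the snippet above):

import torch

dummy_input = torch.randn(1, 3, 512, 512)
torch.onnx.export(
    model.module,        # unwrap nn.DataParallel first
    dummy_input,
    'modnet.onnx',
    export_params=True,
    opset_version=11,    # opset >= 11 matches PyTorch's Upsample/Resize behavior
)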

manthan2305 commented 3 years ago

Okay, thank you for the help. I completely misunderstood the error. It also worked for me. I updated the onnx file here.

ZHKKKe commented 3 years ago

@anilsathyan7 @manthanTECHNOZER @Tomas1337 @manthan3C273 Hey all, sorry for the late response (due to my vacation). Would you like to share the converted onnx model and its inference code with the community? If so, you can share the model and the code with me (I will add them to this repo), or you can open a pull request (please let me know first under this issue). Thanks in advance.

manthan2305 commented 3 years ago

Okay, sure, I'll do it. But I still have a torch dependency for F.interpolate, because the preprocessed inference image looks like this:


import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# resize with 'area' interpolation, as in the demo preprocessing
im = F.interpolate(im, size=(im_rh, im_rw), mode='area')
# (1, 3, H, W) tensor -> (H, W, 3) array for display
im_np = im.squeeze(0).permute(1, 2, 0).numpy()

plt.imshow(im_np)

Output image

Can you help with how to pre-process the input image without F.interpolate? I think there may be an option in OpenCV, but I don't know much about OpenCV.

manthan2305 commented 3 years ago

I just found the solution. I did it the wrong way. I'll share the model and code soon.
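In short, cv2.resize with INTER_AREA is the OpenCV counterpart of F.interpolate(..., mode='area'). A rough sketch (the input path is a placeholder, and the normalization is assumed to match the demo's transforms):

import cv2
import numpy as np

im = cv2.imread('input.jpg')              # HWC, BGR, uint8 (path assumed)
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
# OpenCV takes (width, height); INTER_AREA matches mode='area'
im = cv2.resize(im, (im_rw, im_rh), interpolation=cv2.INTER_AREA)
# normalize to [-1, 1] and reorder to 1x3xHxW float32 for the onnx model
im = (im.astype(np.float32) / 255.0 - 0.5) / 0.5
im = np.transpose(im, (2, 0, 1))[None, :, :, :]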