pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

Can't export mobilenetv3 model to onnx #3463

Open jkparuchuri opened 3 years ago

jkparuchuri commented 3 years ago

import numpy as np
import torch
from torchvision import models

model = models.mobilenet_v3_small(pretrained=True)
input_np = np.random.uniform(0, 1, (1, 3, 224, 224))
input_var = torch.FloatTensor(input_np)
torch.onnx.export(model, args=(input_var,), f="cnn.onnx", verbose=False, input_names=["input"], output_names=["output"])

Error: RuntimeError: Exporting the operator hardsigmoid to ONNX opset version 9 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.

cc @neginraoof

datumbox commented 3 years ago

@jkparuchuri Thanks for reporting. MobileNetV3 makes use of hardsigmoid, so if that's not supported by ONNX it can't run on it.

@neginraoof I wonder if you could recommend a workaround? How hard would it be to add support for Hardsigmoid?

zhiqwang commented 3 years ago

Hi

It seems that nn.Hardswish caused this problem. The nightly version of PyTorch has already addressed it, which is why the unit test passes.

If you are using PyTorch 1.7.x, you can replace it with an export-friendly version of Hardswish, as below, and set the ONNX opset version to 11.

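import torch.nn as nn
import torch.nn.functional as F


class Hardswish(nn.Module):
    """
    Export-friendly version of nn.Hardswish()
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # x * relu6(x + 3) / 6, written with hardtanh so the exporter can handle it
        return x * F.hardtanh(x + 3, 0., 6.) / 6.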
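
As a quick sanity check on the formula above, here is a minimal plain-Python sketch (no PyTorch needed) that compares the hardtanh-based expression against the piecewise definition of hardswish on a grid of sample points. The helper names `hardtanh`, `hardswish_reference` and `hardswish_export_friendly` are made up for this illustration and operate on single floats, not tensors.

```python
def hardtanh(v, lo, hi):
    """Clamp v to [lo, hi], which is what hardtanh does element-wise."""
    return min(max(v, lo), hi)


def hardswish_reference(x):
    """Piecewise definition of hardswish."""
    if x <= -3.0:
        return 0.0
    if x >= 3.0:
        return x
    return x * (x + 3.0) / 6.0


def hardswish_export_friendly(x):
    """Same expression as the forward() above, applied to a plain float."""
    return x * hardtanh(x + 3.0, 0.0, 6.0) / 6.0


# Sample points from -6.0 to 6.0 in steps of 0.25 (covers both clamp boundaries).
samples = [i / 4.0 for i in range(-24, 25)]
max_abs_diff = max(
    abs(hardswish_reference(x) - hardswish_export_friendly(x)) for x in samples
)
print("largest absolute difference:", max_abs_diff)
```

The difference should be zero up to floating-point rounding, which is why swapping the module should not change the model's outputs.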

jkparuchuri commented 3 years ago

@zhiqwang I tried PyTorch 1.8 and torchvision 0.9, released today, and I still get the same error. Are you sure it has already been fixed?

RuntimeError: Exporting the operator hardsigmoid to ONNX opset version 9 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.

zhiqwang commented 3 years ago

Hi @jkparuchuri

Sorry, I hadn't fully tested my proposal, and you are right. I had only tested exporting nn.Hardswish, and I missed the Hardsigmoid in

https://github.com/pytorch/vision/blob/c991db82abba12e664eeac14c9b643d0f1f1a7df/torchvision/models/mobilenetv3.py#L35

As you mentioned, torch.onnx.export doesn't support this operator.

I replaced this operator with F.hardtanh to solve the problem (I've tested the whole model now). You can do something like the following to address this error.

diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py
index 1e2606d..fdf1f6d 100644
--- a/torchvision/models/mobilenetv3.py
+++ b/torchvision/models/mobilenetv3.py
@@ -32,7 +32,7 @@ class SqueezeExcitation(nn.Module):
         scale = self.fc1(scale)
         scale = self.relu(scale)
         scale = self.fc2(scale)
-        return F.hardsigmoid(scale, inplace=inplace)
+        return F.hardtanh(scale + 3, 0., 6., inplace=inplace) / 6.

     def forward(self, input: Tensor) -> Tensor:
         scale = self._scale(input, True)

There are now two ways to fix this bug:

  1. Add native support for exporting F.hardsigmoid to ONNX.
  2. Replace F.hardsigmoid with an F.hardtanh expression that is export-friendly and numerically equivalent, as I did above.
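
To illustrate the "numerically equivalent" claim in option 2, the following plain-Python sketch compares the piecewise definition of hardsigmoid with the `hardtanh(x + 3, 0, 6) / 6` expression from the diff. It works on single floats rather than tensors, and the function names are hypothetical helpers for this check only.

```python
def hardtanh(v, lo, hi):
    """Clamp v to [lo, hi], which is what hardtanh does element-wise."""
    return min(max(v, lo), hi)


def hardsigmoid_reference(x):
    """Piecewise definition of hardsigmoid."""
    if x <= -3.0:
        return 0.0
    if x >= 3.0:
        return 1.0
    return x / 6.0 + 0.5


def hardsigmoid_via_hardtanh(x):
    """Replacement expression from the diff, applied to a plain float."""
    return hardtanh(x + 3.0, 0.0, 6.0) / 6.0


# Sample points from -6.0 to 6.0 in steps of 0.25, including x = -3 and x = 3.
samples = [i / 4.0 for i in range(-24, 25)]
worst = max(
    abs(hardsigmoid_reference(x) - hardsigmoid_via_hardtanh(x)) for x in samples
)
print("largest absolute difference:", worst)
```

Any difference should be at the level of floating-point rounding, since `x / 6 + 0.5` and `(x + 3) / 6` are the same quantity computed in a different order.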

Also, exporting mobilenetv3 to ONNX is not covered by the unit tests; maybe we could add a test like test_shufflenet_v2_dynamic_axes.

cc @datumbox @fmassa What is your suggestion?

jkparuchuri commented 3 years ago

Thank you @zhiqwang

fmassa commented 3 years ago

Another option could be to register a custom ONNX op for hardsigmoid in https://github.com/pytorch/vision/blob/master/torchvision/ops/_register_onnx_ops.py , so that we can keep torchvision using the faster hardsigmoid while ONNX uses hardtanh.

@neginraoof what do you think would be preferable? Are there plans to add those missing operators to ONNX?

ThomAub commented 3 years ago

I agree with @fmassa and I think there is no need for a custom op for HardSigmoid. I saw https://github.com/onnx/onnx/blob/master/docs/Operators.md#HardSigmoid so, although I might be mistaken, it may just be a matter of adding and registering this operator on the PyTorch side?
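
One detail worth keeping in mind for such a mapping: the ONNX HardSigmoid operator is defined as max(0, min(1, alpha * x + beta)) with defaults alpha = 0.2 and beta = 0.5, so it only reproduces the hardsigmoid used here if alpha is set to 1/6. The sketch below checks this in plain Python on single floats; the function names are hypothetical, and it says nothing about how the exporter itself ends up implementing the mapping.

```python
def onnx_hardsigmoid(x, alpha=0.2, beta=0.5):
    """ONNX HardSigmoid formula; the defaults are the ones in the ONNX operator docs."""
    return max(0.0, min(1.0, alpha * x + beta))


def relu6_hardsigmoid(x):
    """The hardsigmoid variant discussed in this thread: relu6(x + 3) / 6."""
    return min(max(x + 3.0, 0.0), 6.0) / 6.0


# Sample points from -6.0 to 6.0 in steps of 0.25.
samples = [i / 4.0 for i in range(-24, 25)]

# With alpha = 1/6 (and the default beta = 0.5) the two agree up to rounding.
diff_matched = max(
    abs(onnx_hardsigmoid(x, alpha=1.0 / 6.0) - relu6_hardsigmoid(x)) for x in samples
)

# With the ONNX default alpha = 0.2 the slopes differ, so the outputs diverge.
diff_default = max(
    abs(onnx_hardsigmoid(x) - relu6_hardsigmoid(x)) for x in samples
)

print("alpha=1/6:", diff_matched)
print("alpha=0.2 (ONNX default):", diff_default)
```

So a symbolic for hardsigmoid would presumably need to pass alpha explicitly rather than rely on the ONNX default.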

zhiqwang commented 3 years ago

Hi @fmassa, @ThomAub - It seems that https://github.com/pytorch/pytorch/issues/49649 and https://github.com/pytorch/pytorch/pull/54193 are addressing this issue now.