open-mmlab / mmcv

OpenMMLab Computer Vision Foundation
https://mmcv.readthedocs.io/en/latest/
Apache License 2.0

Failed to use jit.trace to trace a model with Correlation #1645

Closed: fortuneko closed this issue 2 years ago

fortuneko commented 2 years ago

To export FlowNet2, I followed the steps below:

Step 1: modify flownet_decoder.py to return a list of tensors

+++ b/mmflow/models/decoders/flownet_decoder.py
@@ -359,10 +359,13 @@ class FlowNetSDecoder(BaseDecoder):
         flow_result = F.interpolate(
             flow_result, size=(H, W), mode='bilinear', align_corners=False)
         # reshape [2, H, W] to [H, W, 2]
+        # flow_result = flow_result.permute(0, 2, 3,
+        #                                   1).cpu().data.numpy() * self.flow_div
         flow_result = flow_result.permute(0, 2, 3,
-                                          1).cpu().data.numpy() * self.flow_div
+                                          1) * self.flow_div
         # unravel batch dim
         flow_result = list(flow_result)
+        return flow_result

Step 2: code to trace the model

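    import functools
    import torch
    from mmflow.apis import init_model

    # Build the model from the config and checkpoint given on the command line.
    model = init_model(args.config, args.checkpoint, device=args.device)
    # Bind test_mode=True so that tracing only has to supply the image tensor.
    model.forward = functools.partial(model.forward, test_mode=True)
    device = next(model.parameters()).device  # model device
    # Dummy input with 6 channels (an image pair concatenated along channels).
    img = torch.rand(1, 6, 320, 320).float().to(device)
    traced_script_module = torch.jit.trace(model, img)
    traced_script_module.save('traced_flownet.pt')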

Step 3: running this with a specific config and checkpoint fails with the following error:

 File "image_demo.py", line 79, in export_jit
    traced_script_module.save(outfile)
  File "/home/user/anaconda3/envs/yolox/lib/python3.7/site-packages/torch/jit/_script.py", line 686, in save
    return self._c.save(str(f), **kwargs)
RuntimeError: 
Could not export Python function call 'CorrelationFunction'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
/home/user/anaconda3/envs/yolox/lib/python3.7/site-packages/mmcv/ops/correlation.py(186): forward
/home/user/anaconda3/envs/yolox/lib/python3.7/site-packages/torch/nn/modules/module.py(1090): _slow_forward
/home/user/anaconda3/envs/yolox/lib/python3.7/site-packages/torch/nn/modules/module.py(1102): _call_impl
/home/user/workspace/gitcode/mmflow/mmflow/models/utils/correlation_block.py(68): forward
/home/user/anaconda3/envs/yolox/lib/python3.7/site-packages/torch/nn/modules/module.py(1090): _slow_forward
/home/user/anaconda3/envs/yolox/lib/python3.7/site-packages/torch/nn/modules/module.py(1102): _call_impl
/home/user/workspace/gitcode/mmflow/mmflow/models/encoders/flownet_encoder.py(157): forward
/home/user/anaconda3/envs/yolox/lib/python3.7/site-packages/torch/nn/modules/module.py(1090): _slow_forward
/home/user/anaconda3/envs/yolox/lib/python3.7/site-packages/torch/nn/modules/module.py(1102): _call_impl
/home/user/workspace/gitcode/mmflow/mmflow/models/flow_estimators/flownet.py(109): extract_feat
/home/user/workspace/gitcode/mmflow/mmflow/models/flow_estimators/flownet2.py(130): forward_test
/home/user/workspace/gitcode/mmflow/mmflow/models/flow_estimators/base.py(61): forward
/home/user/anaconda3/envs/yolox/lib/python3.7/site-packages/torch/nn/modules/module.py(1090): _slow_forward
/home/user/anaconda3/envs/yolox/lib/python3.7/site-packages/torch/nn/modules/module.py(1102): _call_impl
/home/user/anaconda3/envs/yolox/lib/python3.7/site-packages/torch/jit/_trace.py(965): trace_module
/home/user/anaconda3/envs/yolox/lib/python3.7/site-packages/torch/jit/_trace.py(750): trace
image_demo.py(64): export_jit
image_demo.py(86): main
image_demo.py(101): <module>

I understand this problem probably occurs because JIT has no runtime op for CorrelationFunction. Is there any way to work around it?
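The failure can likely be reproduced without mmflow or mmcv. The sketch below is a minimal illustration, assuming a toy `DoubleFunction` (a hypothetical stand-in for `CorrelationFunction`, not part of any library): tracing a module that calls a Python `autograd.Function` succeeds, and the error only appears when the traced module is serialized, which matches the traceback above.

```python
import io

import torch
import torch.nn as nn


class DoubleFunction(torch.autograd.Function):
    """Hypothetical toy op standing in for mmcv's CorrelationFunction."""

    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2


class ToyModel(nn.Module):
    def forward(self, x):
        # The tracer records this call as an opaque Python function call.
        return DoubleFunction.apply(x)


# Tracing itself works, because the Python function is still available.
traced = torch.jit.trace(ToyModel(), torch.rand(1, 3, 8, 8))

# Serialization is where the export is expected to fail.
save_error = None
try:
    torch.jit.save(traced, io.BytesIO())
except RuntimeError as err:
    save_error = err

print(type(save_error).__name__)
```

If this behaves as expected, the problem is not specific to FlowNet2 or to the decoder change in step 1: any Python-level `autograd.Function` in the traced graph blocks `save`.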

grimoire commented 2 years ago

Sorry for the late reply. As far as I know, exporting a custom autograd.Function is not supported in PyTorch. You can try creating a custom extension for that op. Here is an example in MMDeploy: torchscript_support
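A different route from the custom extension suggested here, for cases where compiling one is not an option, is to express the correlation with ordinary tensor ops that the tracer can record. The sketch below is only an illustration under stated assumptions: `NaiveCorrelation` is a hypothetical module, not mmcv's `Correlation`, and it ignores that op's kernel size, stride, and dilation settings, so it should not be expected to reproduce the outputs of a pretrained FlowNet2 checkpoint without further work.

```python
import io

import torch
import torch.nn as nn
import torch.nn.functional as F


class NaiveCorrelation(nn.Module):
    """Hypothetical pure-PyTorch correlation; NOT mmcv's Correlation op."""

    def __init__(self, max_displacement=1):
        super().__init__()
        self.max_displacement = max_displacement

    def forward(self, feat1, feat2):
        d = self.max_displacement
        height, width = feat1.shape[2], feat1.shape[3]
        # Zero-pad feat2 so every shifted window has the same size as feat1.
        padded = F.pad(feat2, (d, d, d, d))
        maps = []
        for dy in range(2 * d + 1):
            for dx in range(2 * d + 1):
                shifted = padded[:, :, dy:dy + height, dx:dx + width]
                # Channel-wise mean of the product is one correlation map.
                maps.append((feat1 * shifted).mean(dim=1, keepdim=True))
        # One output channel per displacement: (2d + 1) ** 2 in total.
        return torch.cat(maps, dim=1)


corr = NaiveCorrelation(max_displacement=1).eval()
feat1 = torch.rand(1, 16, 8, 8)
feat2 = torch.rand(1, 16, 8, 8)

# Only built-in tensor ops are traced, so the module can be serialized.
traced = torch.jit.trace(corr, (feat1, feat2))
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
reloaded = torch.jit.load(buffer)
out = reloaded(feat1, feat2)
```

The Python loop is unrolled at trace time, so the displacement range is fixed in the saved graph. This is usually slower and more memory-hungry than a fused kernel, which is the trade-off against writing a custom extension.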