toandaominh1997 / EfficientDet.Pytorch

Implementation of EfficientDet: Scalable and Efficient Object Detection in PyTorch
MIT License

How to convert to onnx model? #53

Open SHMathRabbit opened 4 years ago

SHMathRabbit commented 4 years ago

@toandaominh1997 Hi, thanks for your great work. I have tried to convert efficientdet-b0.pth to an ONNX model, and my code is as follows:

# torch, EfficientDet and the EFFICIENTDET config dict are imported at module level (from this repo)
def to_onnx(checkpoint_name, save_onnx_name='efficientdet.onnx', size_image=(512, 512)):
    import torch.onnx as onnx

    # Load the checkpoint and rebuild the model with the stored configuration.
    checkpoint = torch.load(checkpoint_name)  # , map_location=lambda storage, loc: storage)
    num_class = checkpoint['num_class']
    network = checkpoint['network']
    model = EfficientDet(num_classes=num_class,
                         network=network,
                         W_bifpn=EFFICIENTDET[network]['W_bifpn'],
                         D_bifpn=EFFICIENTDET[network]['D_bifpn'],
                         D_class=EFFICIENTDET[network]['D_class'],
                         is_training=False)

    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    # Export the model with a dummy input of the expected size.
    dummy_input = torch.randn(1, 3, size_image[0], size_image[1])
    onnx.export(model, dummy_input, save_onnx_name, verbose=True, export_params=True)

to_onnx('checkpoint_VOC_efficientdet-d0_206.pth')

Running the above code produces an efficientdet.onnx file. I also checked the exported model with onnx.checker.check_model and it reports no errors.
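
For reference, the check itself is just the standard onnx checker call (assuming the exported file name from the code above):

    import onnx

    # Load the exported graph and run the structural checker;
    # check_model raises an exception if the graph is invalid.
    onnx_model = onnx.load('efficientdet.onnx')
    onnx.checker.check_model(onnx_model)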

However, when I run demo.py and onnx_test.py on the same image, demo.py produces detections while onnx_test.py does not. My onnx_test.py is as follows:

import onnxruntime as rt
import cv2
import numpy as np

def run(image_name, onnx_path):
    # Read and preprocess the image into a 1x3x512x512 float32 tensor.
    img = cv2.imread(image_name)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (512, 512))
    img = img.transpose(2, 0, 1)[np.newaxis]  # HWC -> NCHW (a plain reshape would scramble the layout)
    img = img.astype(np.float32)

    # Run the exported model with ONNX Runtime.
    sess = rt.InferenceSession(onnx_path)
    input_name = sess.get_inputs()[0].name
    output_name1 = sess.get_outputs()[0].name
    output_name2 = sess.get_outputs()[1].name
    output_name3 = sess.get_outputs()[2].name
    print(output_name1, output_name2, output_name3)
    output_names = [output_name1, output_name2, output_name3]
    pred_onnx = sess.run(output_names, {input_name: img})
    print(pred_onnx)

When I run onnx_test.py, the output is:

2019-12-31 13:32:44.800826576 [W:onnxruntime:, graph.cc:2412 CleanUnusedInitializers] Removing initializer 'efficientnet._blocks.8._bn1.weight'. It is not used by any node and should be removed from the model.
2019-12-31 13:32:44.800834976 [W:onnxruntime:, graph.cc:2412 CleanUnusedInitializers] Removing initializer 'efficientnet._blocks.8._bn2.bias'. It is not used by any node and should be removed from the model.
2019-12-31 13:32:44.800843492 [W:onnxruntime:, graph.cc:2412 CleanUnusedInitializers] Removing initializer 'efficientnet._blocks.8._bn2.num_batches_tracked'. It is not used by any node and should be removed from the model.
2019-12-31 13:32:44.800852497 [W:onnxruntime:, graph.cc:2412 CleanUnusedInitializers] Removing initializer 'efficientnet._blocks.5._depthwise_conv.weight'. It is not used by any node and should be removed from the model.
2019-12-31 13:32:44.800861084 [W:onnxruntime:, graph.cc:2412 CleanUnusedInitializers] Removing initializer 'efficientnet._blocks.8._bn2.running_mean'. It is not used by any node and should be removed from the model.
459 460 461
[array([], dtype=float32), array([], dtype=float32), array([], shape=(0, 4), dtype=float32)]

Could you please take the time to look into the questions above? Thank you very much!

Mut1nyJD commented 4 years ago

I am not the author, but I think I can answer some of your questions. There are certainly issues when exporting this network to ONNX, even with opset 11, at least with PyTorch 1.3 (I have not tried 1.4 yet). The exported ONNX model is pretty much unusable: even though it loads and runs, it produces no output. The problem seems to be in some of the later layers that build the final prediction outputs; they do not appear to be exportable because they use functionality that the PyTorch JIT cannot properly trace. So you would probably have to modify the network structure to take this post-processing step out of the model itself, output the last feature/head layers instead, re-export, and then move the post-processing into the code that runs your ONNX model.
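
Something along these lines is what I mean. This is only a rough sketch, and all of the attribute names used below (backbone/BiFPN and head modules) are made up; you would have to map them onto the real modules of the EfficientDet class in this repo:

    import torch
    import torch.nn as nn

    class EfficientDetForExport(nn.Module):
        """Wraps the trained model so forward() stops at the raw head outputs,
        leaving box decoding / NMS to be done outside the ONNX graph."""

        def __init__(self, det_model):
            super().__init__()
            self.det_model = det_model

        def forward(self, x):
            # Placeholder attribute names: substitute the real backbone/BiFPN
            # and head modules of the EfficientDet class here.
            features = self.det_model.backbone_and_bifpn(x)
            classification = torch.cat(
                [self.det_model.classification_head(f) for f in features], dim=1)
            regression = torch.cat(
                [self.det_model.regression_head(f) for f in features], dim=1)
            return classification, regression

    # export_model = EfficientDetForExport(model).eval()
    # torch.onnx.export(export_model, torch.randn(1, 3, 512, 512),
    #                   'efficientdet_raw.onnx', opset_version=11,
    #                   export_params=True)

Anchor generation, box decoding, score thresholding and NMS would then live in the script that calls ONNX Runtime, where they are plain Python/NumPy and never need to be traced by the exporter.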