HUuxiaobin / HitNet

Camouflaged Object Detection

How to convert pth model to onnx? #11

Open meifannao opened 1 year ago

meifannao commented 1 year ago

My Python code for converting the model is below, but it does not produce an ONNX model:

import torch
import torch.onnx
from lib.pvt import Hitnet

def pth_to_onnx(dummy_input, checkpoint, onnx_path, input_names=['input'], output_names=['output1', 'output2'], device='cuda'):
    # Refuse output paths that do not end in '.onnx'.
    if not onnx_path.endswith('.onnx'):
        print("Warning! The onnx model name is not correct, "
              "please give a name that ends with '.onnx'!")
        return 0
    print(torch.cuda.is_available())
    # Build the network on the requested device and load the trained weights.
    model = Hitnet().to(device)
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    model.eval()

    # Trace the model with the dummy input and write the ONNX graph to onnx_path.
    torch.onnx.export(model, dummy_input, onnx_path, verbose=False, input_names=input_names, output_names=output_names)
    print("Exporting .pth model to onnx model has been successful!")

if __name__ == '__main__':
    checkpoint = './Net_epoch_best.pth'
    onnx_path = './Net_epoch_best.onnx'
    # Dummy input on the same device as the model; the shape is (batch, channels, height, width).
    dummy_input = torch.randn(1, 3, 480, 640, device='cuda')
    pth_to_onnx(dummy_input, checkpoint, onnx_path)

HUuxiaobin commented 10 months ago

We have not implemented the ONNX conversion ourselves; we only provide an ONNX template.
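
Since the question is that no ONNX file appears, one way to narrow it down is to check the output path right after the export call. The following is only a minimal sketch using the standard library: `check_onnx_output` is a hypothetical helper name, not part of HitNet or PyTorch, and it only confirms that a non-empty file was written, not that the exported graph is valid.

```python
from pathlib import Path


def check_onnx_output(onnx_path):
    """Return the size in bytes of the exported file, or raise if it is unusable.

    Hypothetical helper: it checks the file on disk only and does not
    validate the ONNX graph itself.
    """
    path = Path(onnx_path)
    # Same rule as the script in the question: the name must end in '.onnx'.
    if path.suffix != ".onnx":
        raise ValueError(f"expected a path ending in '.onnx', got {onnx_path!r}")
    # A missing file means the export never wrote anything, or wrote elsewhere.
    if not path.is_file():
        raise FileNotFoundError(f"no file was written to {path.resolve()}")
    size = path.stat().st_size
    # A zero-byte file usually means the export was interrupted.
    if size == 0:
        raise ValueError(f"{path} exists but is empty")
    return size
```

Calling `check_onnx_output(onnx_path)` immediately after `torch.onnx.export(...)` in the script above would raise with the resolved absolute path if nothing was written, which also helps when the script is run from a different working directory than expected.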