jeya-maria-jose / UNeXt-pytorch

Official Pytorch Code base for "UNeXt: MLP-based Rapid Medical Image Segmentation Network", MICCAI 2022
https://jeya-maria-jose.github.io/UNext-web/
MIT License

pth to ONNX #36

Open lhehejunl opened 1 year ago

lhehejunl commented 1 year ago

Hi, experts, could you provide the code that converts model.pth to model.onnx? Thanks.

MinGiSa commented 8 months ago

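    from archs import UNext  # archs.py from the UNeXt-pytorch repository
    import torch

    if __name__ == '__main__':
        with torch.no_grad():
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            # Positional args: num_classes=1, input_channels=3, deep_supervision=False;
            # these must match the configuration the checkpoint was trained with
            model = UNext(1, 3, False).to(device)
            checkpointPath = "Your Model File Path"  # e.g. the model.pth saved during training
            destPath = "Save Path, You Want"         # e.g. model.onnx
            # map_location lets a checkpoint saved on GPU load on a CPU-only machine
            model.load_state_dict(torch.load(checkpointPath, map_location=device))
            model.eval()

            # Dummy input used to trace the graph; use the training resolution
            imgHeight = 256
            imgWidth = 256
            batchSize = 1
            dummyInput = torch.rand(batchSize, 3, imgHeight, imgWidth).to(device)
            inputNames = ["input"]
            outputNames = ["output"]
            # Only the batch dimension is dynamic; height and width stay fixed at the export size
            dynamicAxes = {"input": {0: "batchSize"}, "output": {0: "batchSize"}}

            torch.onnx.export(model,
                              dummyInput,
                              destPath,
                              input_names=inputNames,
                              output_names=outputNames,
                              dynamic_axes=dynamicAxes,
                              verbose=True)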
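If `load_state_dict` fails with unexpected keys that all start with `module.`, the checkpoint was probably saved from a model wrapped in `nn.DataParallel` (or `DistributedDataParallel`), which prefixes every parameter name. Below is a minimal sketch of a hypothetical helper, `strip_module_prefix`, that renames the keys before loading; it only touches keys that begin with the prefix, so it is harmless on a checkpoint that has none.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Hypothetical helper: return a copy of state_dict with a leading
    `prefix` removed from each key. Keys without the prefix are kept as-is.

    nn.DataParallel saves parameters as "module.<name>", while a bare
    model expects "<name>".
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


# Assumed usage inside the export script:
# state = torch.load(checkpointPath, map_location=device)
# model.load_state_dict(strip_module_prefix(state))
```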
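The exported graph presumably returns raw logits, since UNext appears to end in a plain convolution with no activation. To turn the ONNX output into binary masks, here is a sketch that assumes a sigmoid followed by a 0.5 threshold (what the repo's evaluation code appears to do; check it against your own validation script). `logits_to_mask` is a hypothetical helper, and the commented onnxruntime usage assumes the `input`/`output` names from the export script.

```python
import numpy as np


def logits_to_mask(logits, threshold=0.5):
    """Hypothetical helper: convert raw outputs of shape (N, C, H, W)
    into 0/1 masks of the same shape (dtype uint8).

    Assumes the model emits logits, so a sigmoid is applied before
    thresholding. Clipping avoids overflow warnings in np.exp.
    """
    logits = np.asarray(logits, dtype=np.float32)
    probs = 1.0 / (1.0 + np.exp(-np.clip(logits, -50.0, 50.0)))  # element-wise sigmoid
    return (probs > threshold).astype(np.uint8)


# Assumed usage with onnxruntime (image_batch: float32 array, shape (N, 3, 256, 256)):
# import onnxruntime as ort
# session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
# (logits,) = session.run(["output"], {"input": image_batch})
# mask = logits_to_mask(logits)
```

Since sigmoid(x) > 0.5 exactly when x > 0, thresholding the logits at 0 would give the same mask at the default threshold; the explicit sigmoid just keeps other thresholds meaningful as probabilities.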