milesial / Pytorch-UNet

PyTorch implementation of the U-Net for image semantic segmentation with high quality images
GNU General Public License v3.0

About exporting .pth to ONNX #492

Open JensenHJS opened 7 months ago

JensenHJS commented 7 months ago

I get an error when exporting the model to ONNX. Looking forward to your reply.

My package versions are: torch 2.2.2, torchvision 0.17.2, onnx 1.16.0, onnxruntime 1.17.3.

The export script is as follows:

```python
net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Loading model {args.model}')
state_dict = torch.load(args.model, map_location=device)
mask_values = state_dict.pop('mask_values', [0, 1])
net.load_state_dict(state_dict)
net.eval()

dummy_input = torch.randn(1, 3, 256, 256)
onnx_path = 'unet.onnx'
input_names = "data"
output_names = "output"
torch.onnx.export(net, (dummy_input), onnx_path, input_names=input_names, output_names=output_names, verbose=False)
```
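A likely cause, offered as a guess rather than a confirmed answer: the script assigns plain strings to `input_names` and `output_names`, while the `torch.onnx.export` documentation describes both parameters as lists of strings (for example `input_names=["data"]` and `output_names=["output"]`). The sketch below uses a hypothetical helper, `as_name_list` (not part of PyTorch or this repository), that normalizes either form into a list. It needs only the standard library, so the torch call is shown as a comment.

```python
from typing import Iterable, List, Union


def as_name_list(names: Union[str, Iterable[str]]) -> List[str]:
    """Hypothetical helper: normalize ONNX input/output names to a list of str.

    A bare string is wrapped in a one-element list instead of being
    iterated character by character.
    """
    if isinstance(names, str):
        return [names]
    result = list(names)
    # Every entry must itself be a str, matching the "list of str" contract.
    for name in result:
        if not isinstance(name, str):
            raise TypeError(f"ONNX names must be str, got {type(name).__name__}")
    return result


# Applied to the export call from the script (requires torch, so commented out):
# torch.onnx.export(net, (dummy_input,), onnx_path,
#                   input_names=as_name_list("data"),
#                   output_names=as_name_list("output"),
#                   verbose=False)
```

Writing the lists inline (`input_names=["data"]`) is equivalent; the helper only makes the expected type explicit.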

The error is:

```
Traceback (most recent call last):
  File "/home/ubuntu/CODE/Pytorch-UNet/export_onnx.py", line 53, in
    torch.onnx.export(net, (dummy_input), onnx_path, input_names=input_names, output_names=output_names, verbose=False)
  File "/home/ubuntu/miniconda3/envs/unet/lib/python3.10/site-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/home/ubuntu/miniconda3/envs/unet/lib/python3.10/site-packages/torch/onnx/utils.py", line 1613, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/ubuntu/miniconda3/envs/unet/lib/python3.10/site-packages/torch/onnx/utils.py", line 1139, in _model_to_graph
    graph = _optimize_graph(
  File "/home/ubuntu/miniconda3/envs/unet/lib/python3.10/site-packages/torch/onnx/utils.py", line 674, in _optimize_graph
    _C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names)
TypeError: _jit_pass_onnx_set_dynamic_input_shape(): incompatible function arguments. The following argument types are supported:
    1. (arg0: torch::jit::Graph, arg1: Dict[str, Dict[int, str]], arg2: List[str]) -> None
```
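The last lines of the traceback show the only signature the internal pass accepts, and the frame just above it passes `(graph, dynamic_axes, input_names)`, so `arg2: List[str]` corresponds to `input_names`. The pure-Python sketch below re-states that contract. It is a simplified, hypothetical check, not PyTorch's code (the real conversion is done by the pybind11 binding), and `example_axes` with its batch/height/width labels is an assumed value for illustration only; the script in this issue does not pass `dynamic_axes`.

```python
from collections.abc import Sequence
from typing import Any


def is_list_of_str(value: Any) -> bool:
    """True if value fits List[str]."""
    # A str is itself a Sequence of one-character strs, so reject it
    # explicitly; otherwise "data" would look like a valid list of names.
    if isinstance(value, (str, bytes)) or not isinstance(value, Sequence):
        return False
    return all(isinstance(item, str) for item in value)


def is_dynamic_axes(value: Any) -> bool:
    """True if value fits Dict[str, Dict[int, str]]."""
    if not isinstance(value, dict):
        return False
    for key, axes in value.items():
        if not isinstance(key, str) or not isinstance(axes, dict):
            return False
        for dim, label in axes.items():
            # bool is a subclass of int; exclude it so True/False are not axes.
            if isinstance(dim, bool) or not isinstance(dim, int):
                return False
            if not isinstance(label, str):
                return False
    return True


def matches_signature(dynamic_axes: Any, input_names: Any) -> bool:
    """Hypothetical mirror of (Graph, Dict[str, Dict[int, str]], List[str])."""
    return is_dynamic_axes(dynamic_axes) and is_list_of_str(input_names)


# Assumed dynamic axes for an NCHW input named "data" (illustration only).
example_axes = {"data": {0: "batch", 2: "height", 3: "width"}}

print(matches_signature({}, "data"))              # value from the script: False
print(matches_signature({}, ["data"]))            # wrapped in a list: True
print(matches_signature(example_axes, ["data"]))  # with dynamic axes: True
```

The explicit string test in `is_list_of_str` is the important design choice: a naive "is every element a str?" check would accept `"data"`, because iterating a string yields one-character strings.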