jamycheung / DELIVER

Repository of DELIVER dataset and CMNeXt models (CVPR 2023)
Apache License 2.0

Exported ONNX model does not work #14

Closed: 1404561326521 closed this issue 2 months ago

1404561326521 commented 2 months ago

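Hello, when I export the model to ONNX I get a warning (screenshot not preserved), and the exported ONNX model ends up being incorrect (screenshot not preserved). This is my conversion script:

```python
import torch
from semseg.models.cmnext import CMNeXt

model = CMNeXt(backbone='CMNeXt-B2', num_classes=9, modals=['img', 'depth'])

# map_location='cpu' lets the checkpoint load even on a machine without CUDA
checkpoint = torch.load(r'D:\Segmentation_2D\DELIVER-RGBD-ToothSeg\output\DELIVER_CMNeXt-B2_id\CMNeXt_CMNeXt-B2_DELIVER_epoch410_71.46.pth', map_location='cpu')

# The weights may be wrapped under 'model_state_dict'; keys saved from
# (Distributed)DataParallel start with 'module.', which must be stripped
state_dict = checkpoint.get('model_state_dict', checkpoint)
state_dict = {k[len('module.'):] if k.startswith('module.') else k: v
              for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

img_shape = (1, 3, 512, 512)
depth_shape = (1, 3, 512, 512)

img_input = torch.randn(img_shape, device=device)
depth_input = torch.randn(depth_shape, device=device)

input_names = ["ImgInput", "DepthInput"]
output_names = ["Output"]  # naming the output lets it get dynamic axes too
dynamic_axes = {'ImgInput': {0: 'batch_size', 2: 'height', 3: 'width'},
                'DepthInput': {0: 'batch_size', 2: 'height', 3: 'width'},
                'Output': {0: 'batch_size', 2: 'height', 3: 'width'}}

model_path = r'D:\Segmentation_2D\DELIVER-RGBD-ToothSeg\output\DELIVER_CMNeXt-B2_id\CMNeXt_CMNeXt-B2_DELIVER_epoch410_71.46.onnx'

# CMNeXt.forward takes one argument: a list of modality tensors.
# Wrapping that list in a one-element tuple passes it as that single argument.
with torch.no_grad():
    torch.onnx.export(model, ([img_input, depth_input],), model_path,
                      input_names=input_names, output_names=output_names,
                      dynamic_axes=dynamic_axes, verbose=True, opset_version=11)
```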
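The checkpoint-loading branch in the script checks for `"module"` in the first top-level key but then reads `checkpoint['model_state_dict']`, so the two branches assume different checkpoint layouts. Below is a minimal, torch-free sketch of normalizing the keys before `load_state_dict`, assuming the checkpoint is either a bare state dict or one wrapped under `'model_state_dict'`. The helper name `extract_state_dict` is hypothetical, not part of the DELIVER repo.

```python
from typing import Any, Mapping


def extract_state_dict(checkpoint: Mapping[str, Any]) -> dict:
    """Return a plain state dict with any DataParallel 'module.' prefix removed.

    Hypothetical helper. Assumption: the weights are either the checkpoint
    itself or stored under the 'model_state_dict' key.
    """
    state_dict = checkpoint.get("model_state_dict", checkpoint)
    prefix = "module."
    cleaned = {}
    for key, value in state_dict.items():
        # Strip the prefix only at the start of the key; str.replace would
        # also rewrite any 'module.' that appears later in a parameter name.
        cleaned[key[len(prefix):] if key.startswith(prefix) else key] = value
    return cleaned
```

With PyTorch available, this would be used as `model.load_state_dict(extract_state_dict(torch.load(path, map_location="cpu")))`.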

jamycheung commented 2 months ago

Hi, thanks for your interest.

We have not tried exporting the checkpoint to ONNX. It looks like a shape inconsistency. It would be great if you could test it and let us know.
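One way to test the export is to run the same inputs through PyTorch and the ONNX model and compare the logits. The sketch below is a hypothetical helper, `compare_logits`, and it uses only NumPy. It assumes both outputs have already been converted to arrays of shape (N, C, H, W). A shape mismatch is reported first, since it is the symptom the maintainer suspects. Otherwise the helper reports the maximum absolute difference and how many pixels get the same argmax class.

```python
import numpy as np


def compare_logits(torch_out: np.ndarray, onnx_out: np.ndarray, atol: float = 1e-3) -> dict:
    """Compare PyTorch and ONNX segmentation logits of shape (N, C, H, W).

    Hypothetical helper; the tolerance is an illustrative default.
    """
    if torch_out.shape != onnx_out.shape:
        # Different shapes suggest a size was baked in as a constant during tracing
        return {"shape_match": False,
                "torch_shape": torch_out.shape,
                "onnx_shape": onnx_out.shape}
    diff = np.abs(torch_out.astype(np.float64) - onnx_out.astype(np.float64))
    # Fraction of pixels whose predicted class (argmax over channels) agrees
    agree = np.mean(torch_out.argmax(axis=1) == onnx_out.argmax(axis=1))
    return {"shape_match": True,
            "max_abs_diff": float(diff.max()),
            "allclose": bool(np.allclose(torch_out, onnx_out, atol=atol)),
            "pixel_agreement": float(agree)}
```

The ONNX side could come from ONNX Runtime (an assumption; it is not used in the original script), for example `onnxruntime.InferenceSession(model_path, providers=["CPUExecutionProvider"]).run(None, {"ImgInput": img, "DepthInput": depth})[0]`. Testing a resolution other than 512×512 checks whether the dynamic height and width axes actually hold.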