# Load trained weights into `model`, then switch it to inference mode on the
# best available device.
#
# Two checkpoint layouts are accepted:
#   * a training checkpoint dict that wraps the weights under 'model_state_dict'
#   * a bare state dict
# Keys saved from a wrapped model (e.g. nn.DataParallel) carry a leading
# 'module.' prefix, which is stripped so they match the unwrapped `model`.
#
# The original code tested the *outer* dict's first key for "module" but then
# indexed checkpoint['model_state_dict'], so raw DataParallel state dicts
# raised KeyError and wrapped checkpoints were passed to load_state_dict as-is.
state_dict = checkpoint.get('model_state_dict', checkpoint)
prefix = 'module.'
# Strip the prefix only at the START of each key: str.replace would also
# rewrite keys that merely contain 'module.' in the middle of their path.
state_dict = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(state_dict)
model.eval()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
作者您好，我在将模型导出为 ONNX 时会出现警告，
导致最终导出的 ONNX 模型不正确。
# Conversion script: export a trained CMNeXt-B2 (RGB + depth) checkpoint to ONNX.
import torch
from semseg.models.cmnext import CMNeXt

CKPT_PATH = r'D:\Segmentation_2D\DELIVER-RGBD-ToothSeg\output\DELIVER_CMNeXt-B2_id\CMNeXt_CMNeXt-B2_DELIVER_epoch410_71.46.pth'
ONNX_PATH = r'D:\Segmentation_2D\DELIVER-RGBD-ToothSeg\output\DELIVER_CMNeXt-B2_id\CMNeXt_CMNeXt-B2_DELIVER_epoch410_71.46.onnx'

# Pick the device first so every later step agrees on it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = CMNeXt(backbone='CMNeXt-B2', num_classes=9, modals=['img', 'depth'])

# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
# machine; the model is moved to `device` once the weights are in place.
checkpoint = torch.load(CKPT_PATH, map_location='cpu')

# Accept either a training checkpoint ({'model_state_dict': ...}) or a bare
# state dict, and drop a leading 'module.' prefix left by a wrapped model
# (e.g. nn.DataParallel). Only the prefix is stripped, never inner matches.
state_dict = checkpoint.get('model_state_dict', checkpoint)
prefix = 'module.'
state_dict = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(state_dict)
model.eval()
model.to(device)

# Dummy inputs used only to trace the graph.
img_shape = (1, 3, 512, 512)
depth_shape = (1, 3, 512, 512)
img_input = torch.randn(img_shape, device=device)
depth_input = torch.randn(depth_shape, device=device)

input_names = ['ImgInput', 'DepthInput']
# NOTE(review): assumes the model returns a single (N, C, H, W) logits tensor
# — confirm. Without a dynamic output entry, the output shape is fixed to the
# 512x512 trace size in the exported graph.
output_names = ['Output']
spatial_axes = {0: 'batch_size', 2: 'height', 3: 'width'}
dynamic_axes = {name: dict(spatial_axes) for name in input_names + output_names}

# NOTE(review): if the export warnings are TracerWarnings such as
# "Converting a tensor to a Python ...", the traced value is frozen into the
# graph as a constant. If any such value depends on the input height/width,
# the model is only valid at 512x512 despite `dynamic_axes`. Validate the
# .onnx file (e.g. with onnxruntime) at a second resolution against PyTorch.
with torch.no_grad():
    torch.onnx.export(
        model,
        # `args` is the tuple of positional arguments. The model is called
        # with ONE argument — the list of modality tensors — so that list is
        # wrapped in a 1-tuple rather than passed bare.
        ([img_input, depth_input],),
        ONNX_PATH,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        verbose=True,
        opset_version=11,
    )