Open lmq5294249 opened 1 year ago
你好 有具体错误可以贴一下么?感谢!
train.py文件save_network方法改成这样
def save_network(network, epoch_label):
save_filename = 'net_%s.pth'% epoch_label
save_path = os.path.join('./model',name,save_filename)
# torch.save(network.cpu().state_dict(), save_path)
# 上面注释的部分改成下面的
torch.save(network, save_path)
if torch.cuda.is_available():
network.cuda(gpu_ids[0])
然后新建一个py文件,内容如下(其中输入模型和输出模型路径改成自己的):
import torch
import torch.nn
import onnx
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
# 路径改成训练输出模型的位置
model = torch.load('/project/train/src_repo/model/ft_ResNet50/net_last.pth', map_location=device)
model.eval()
input_names = ['input']
output_names = ['output']
x = torch.randn(1, 3, 224, 224, device=device)
# 路径改为转换onnx模型的位置
torch.onnx.export(model, x, '/project/train/src_repo/model/ft_ResNet50/net_last.onnx', input_names=input_names, output_names=output_names, verbose='True')
我尝试将模型转为onnx出现错误,无法解决