layumi / Person_reID_baseline_pytorch

:bouncing_ball_person: Pytorch ReID: A tiny, friendly, strong pytorch implement of person re-id / vehicle re-id baseline. Tutorial 👉https://github.com/layumi/Person_reID_baseline_pytorch/tree/master/tutorial
https://www.zdzheng.xyz
MIT License
4.14k stars 1.01k forks source link

能不能提供一个 onnx_export.py方便将模型转为onnx,以便再转为NCNN等框架可推理的模型 #363

Open lmq5294249 opened 1 year ago

lmq5294249 commented 1 year ago

我尝试将模型转为onnx出现错误,无法解决

layumi commented 1 year ago

你好 有具体错误可以贴一下么?感谢!

CaptainJi commented 1 year ago

train.py文件save_network方法改成这样

def save_network(network, epoch_label):
    save_filename = 'net_%s.pth'% epoch_label
    save_path = os.path.join('./model',name,save_filename)
    # torch.save(network.cpu().state_dict(), save_path)
    # 上面注释的部分改成下面的
    torch.save(network, save_path)
    if torch.cuda.is_available():
        network.cuda(gpu_ids[0])

然后新建一个py文件,内容如下(其中输入模型和输出模型路径改成自己的):

import torch
import torch.nn
import onnx

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
 # 路径改成训练输出模型的位置
model = torch.load('/project/train/src_repo/model/ft_ResNet50/net_last.pth', map_location=device)
model.eval()

input_names = ['input']
output_names = ['output']

x = torch.randn(1, 3, 224, 224, device=device)
 # 路径改为转换onnx模型的位置
torch.onnx.export(model, x, '/project/train/src_repo/model/ft_ResNet50/net_last.onnx', input_names=input_names, output_names=output_names, verbose='True')