Open shining-love opened 4 years ago
用readme里的代码转换:
def torch_export(model, save_path):
model.eval()
data = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, data)
traced_script_module.save(save_path)
print("export finish.")
用readme里的代码转换:
def torch_export(model, save_path): model.eval() data = torch.rand(1, 3, 224, 224) traced_script_module = torch.jit.trace(model, data) traced_script_module.save(save_path) print("export finish.")
def main(model_path, backbone, scale, path, save_path, gpu_id): if os.path.exists(save_path): shutil.rmtree(save_path, ignore_errors=True) if not os.path.exists(save_path): os.makedirs(save_path) save_img_folder = os.path.join(save_path, 'img') if not os.path.exists(save_img_folder): os.makedirs(save_img_folder) save_txt_folder = os.path.join(save_path, 'result') if not os.path.exists(save_txt_folder): os.makedirs(save_txt_folder) img_paths = [os.path.join(path, x) for x in os.listdir(path)] net = PSENet(backbone=backbone, pretrained=False, result_num=config.n) model = Pytorch_model(model_path, net=net, scale=scale, gpu_id=gpu_id) total_frame = 0.0 total_time = 0.0 for img_path in tqdm(img_paths): img_name = os.path.basename(img_path).split('.')[0] save_name = os.path.join(save_txtfolder, 'res' + imgname + '.txt') , boxes_list, t = model.predict(img_path) total_frame += 1 total_time += t
# cv2.imwrite(os.path.join(save_img_folder, '{}.jpg'.format(img_name)), img)
np.savetxt(save_name, boxes_list.reshape(-1, 8), delimiter=',', fmt='%d')
print('fps:{}'.format(total_frame / total_time))
return save_txt_folder
请问model是这个wenmuzhou代码里面的这个model变量吗?大佬你的这个reademe里的代码是放在这个 model = Pytorch_model(model_path, net=net, scale=scale, gpu_id=gpu_id)后面吗?
是这个: net = PSENet(backbone=backbone, pretrained=False, result_num=config.n)
在Pytorch_model的初始化里有加载模型参数代码
或者单独建一个文件,把PSENet参数加载上去,导出
是这个: net = PSENet(backbone=backbone, pretrained=False, result_num=config.n)
在Pytorch_model的初始化里有加载模型参数代码
或者单独建一个文件,把PSENet参数加载上去,导出
好的谢谢大佬
大佬你好。最近在搞部署这一块。很幸运看到你的这个部署repo。大佬能把模型从pytorch代码转为torchscript完整的代码开源一下吗?小白感觉无从下手。感谢