Whu-wxy / PSENet-libtorch

Text detection network psenet deployed by libtorch and Qt.
14 stars 7 forks source link

大佬,我求助你一下 #3

Open shining-love opened 4 years ago

shining-love commented 4 years ago

大佬你好。最近在搞部署这一块。很幸运看到你的这个部署repo。大佬能把模型从pytorch代码转为torchscript完整的代码开源一下吗?小白感觉无从下手。感谢

Whu-wxy commented 4 years ago

用readme里的代码转换:

def torch_export(model, save_path):
  model.eval()
  data = torch.rand(1, 3, 224, 224)
  traced_script_module = torch.jit.trace(model, data)
  traced_script_module.save(save_path)
  print("export finish.")
shining-love commented 4 years ago

用readme里的代码转换:

def torch_export(model, save_path):
  model.eval()
  data = torch.rand(1, 3, 224, 224)
  traced_script_module = torch.jit.trace(model, data)
  traced_script_module.save(save_path)
  print("export finish.")

def main(model_path, backbone, scale, path, save_path, gpu_id): if os.path.exists(save_path): shutil.rmtree(save_path, ignore_errors=True) if not os.path.exists(save_path): os.makedirs(save_path) save_img_folder = os.path.join(save_path, 'img') if not os.path.exists(save_img_folder): os.makedirs(save_img_folder) save_txt_folder = os.path.join(save_path, 'result') if not os.path.exists(save_txt_folder): os.makedirs(save_txt_folder) img_paths = [os.path.join(path, x) for x in os.listdir(path)] net = PSENet(backbone=backbone, pretrained=False, result_num=config.n) model = Pytorch_model(model_path, net=net, scale=scale, gpu_id=gpu_id) total_frame = 0.0 total_time = 0.0 for img_path in tqdm(img_paths): img_name = os.path.basename(img_path).split('.')[0] save_name = os.path.join(save_txtfolder, 'res' + imgname + '.txt') , boxes_list, t = model.predict(img_path) total_frame += 1 total_time += t

img = draw_bbox(img_path, boxes_list, color=(0, 0, 255))

    # cv2.imwrite(os.path.join(save_img_folder, '{}.jpg'.format(img_name)), img)
    np.savetxt(save_name, boxes_list.reshape(-1, 8), delimiter=',', fmt='%d')
print('fps:{}'.format(total_frame / total_time))
return save_txt_folder

请问model是这个wenmuzhou代码里面的这个model变量吗?大佬你的这个reademe里的代码是放在这个 model = Pytorch_model(model_path, net=net, scale=scale, gpu_id=gpu_id)后面吗?

Whu-wxy commented 4 years ago

是这个: net = PSENet(backbone=backbone, pretrained=False, result_num=config.n)

在Pytorch_model的初始化里有加载模型参数代码

或者单独建一个文件,把PSENet参数加载上去,导出

shining-love commented 4 years ago

是这个: net = PSENet(backbone=backbone, pretrained=False, result_num=config.n)

在Pytorch_model的初始化里有加载模型参数代码

或者单独建一个文件,把PSENet参数加载上去,导出

好的谢谢大佬