ysm022 opened this issue 4 years ago
Hi, we have debugged this issue; please pull the latest code (commit bab20b275b580e7fb01ac0b7bcdca53a42abae14).
Server: CUDA 10.0, with PaddlePaddle installed as paddlepaddle-gpu 1.8.1.post107.
In util/util.py, adding the `with` block produces the error shown in the title. Without the `with` block, running the original code fails with: `AssertionError: We Only support save_dygraph in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative Mode`
```python
import paddle.fluid as fluid
import torch


def torch_weight_to_paddle_model(torch_weight_file, paddle_model):
    torch_weight = torch.load(torch_weight_file)
    # Per the report, every statement after torch.load was moved inside this guard.
    with fluid.dygraph.guard():
        # Save the freshly initialised Paddle weights, then reload them as a dict of numpy arrays.
        fluid.dygraph.save_dygraph(paddle_model.state_dict(), './pretrained/resnet18-torch')
        paddle_weight, _ = fluid.load_dygraph('./pretrained/resnet18-torch')
        for k, p in torch_weight.items():
            if k in paddle_weight:
                # Same parameter name in both frameworks: copy it if the shapes agree.
                np_parm = torch_weight[k].detach().numpy()
                if np_parm.shape == paddle_weight[k].shape:
                    paddle_weight[k] = np_parm
                else:
                    print('torch parm {} does not match paddle parm {}'.format(k, k))
            elif 'running_mean' in k:
                # BatchNorm: torch '<prefix>running_mean' -> paddle '<prefix>_mean'.
                np_parm = torch_weight[k].detach().numpy()
                if np_parm.shape == paddle_weight[k[:-12] + '_mean'].shape:
                    paddle_weight[k[:-12] + '_mean'] = np_parm
                else:
                    print('torch parm {} does not match paddle parm {}'.format(k, k[:-12] + '_mean'))
            elif 'running_var' in k:
                # BatchNorm: torch '<prefix>running_var' -> paddle '<prefix>_variance'.
                np_parm = torch_weight[k].detach().numpy()
                if np_parm.shape == paddle_weight[k[:-11] + '_variance'].shape:
                    paddle_weight[k[:-11] + '_variance'] = np_parm
                else:
                    print('torch parm {} does not match paddle parm {}'.format(k, k[:-11] + '_variance'))
            else:
                print('torch parm {} does not exist in paddle model'.format(k))
        # Load the merged weights into the model and save the converted checkpoint.
        paddle_model.set_dict(paddle_weight)
        fluid.dygraph.save_dygraph(paddle_model.state_dict(), './pretrained/resnet18-torch')
```
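Hello, I downloaded the torch model and placed it at ./pretrained/resnet18-5c106cde.pth, then ran convert.py and got an error. The failing statement is `fluid.save_dygraph(paddle_model.state_dict(), './pretrained/resnet18-torch')`. In util/util.py I added `with fluid.dygraph.guard():` at line 9 and moved all the remaining statements of the function body inside the `with` block, but then I get the error shown in the title. How should I fix this? Thanks.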
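The key renaming inside `torch_weight_to_paddle_model` (`k[:-12] + '_mean'` and `k[:-11] + '_variance'`) can be checked on its own, without PaddlePaddle or PyTorch installed. The sketch below is a hypothetical stand-alone helper: the names `torch_key_to_paddle_key` and `plan_mapping` are illustrative and not part of the repository, and the Paddle key format (`bn1._mean`, `bn1._variance`) is assumed from the slicing in the code above. It only matches parameter names; the shape comparison stays in the original function.

```python
def torch_key_to_paddle_key(torch_key: str) -> str:
    """Map a PyTorch state_dict key to the assumed Paddle dygraph key."""
    # Suffix renames mirror k[:-12] + '_mean' and k[:-11] + '_variance' in util.py.
    renames = {"running_mean": "_mean", "running_var": "_variance"}
    for torch_suffix, paddle_suffix in renames.items():
        if torch_key.endswith(torch_suffix):
            return torch_key[: -len(torch_suffix)] + paddle_suffix
    # All other parameters are assumed to share the same name in both frameworks.
    return torch_key


def plan_mapping(torch_keys, paddle_keys):
    """Return (mapping, unmatched): torch key -> paddle key, plus torch keys with no Paddle counterpart."""
    paddle_keys = set(paddle_keys)
    mapping, unmatched = {}, []
    for k in torch_keys:
        target = torch_key_to_paddle_key(k)
        if target in paddle_keys:
            mapping[k] = target
        else:
            unmatched.append(k)
    return mapping, unmatched


if __name__ == "__main__":
    # Hypothetical key lists for one conv + BatchNorm pair.
    torch_keys = ["conv1.weight", "bn1.weight", "bn1.running_mean",
                  "bn1.running_var", "bn1.num_batches_tracked"]
    paddle_keys = ["conv1.weight", "bn1.weight", "bn1.bias", "bn1._mean", "bn1._variance"]
    mapping, unmatched = plan_mapping(torch_keys, paddle_keys)
    print(mapping)
    print(unmatched)
```

Using `endswith` rather than the substring test `'running_mean' in k` avoids slicing a key whose suffix is something else. Keys such as `num_batches_tracked` have no Paddle counterpart and end up in `unmatched`, which corresponds to the "does not exist in paddle model" branch of the original function.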