VIS-VAR / LGSC-for-FAS

Learning Generalized Spoof Cues for Face Anti-spoofing
MIT License
226 stars 56 forks

TypeError: can't pickle paddle.fluid.core_avx.BlockDesc objects #2

Open ysm022 opened 4 years ago

ysm022 commented 4 years ago

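Hello, I downloaded the torch model and placed it in the corresponding directory as ./pretrained/resnet18-5c106cde.pth. Running convert.py fails at the statement fluid.save_dygraph(paddle_model.state_dict(), './pretrained/resnet18-torch'). I then added with fluid.dygraph.guard(): at line 9 of util/util.py and moved all the remaining statements of the function body inside the with block. The error is now the one shown in the title. How should I fix this? Thanks.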

VIS-VAR commented 4 years ago

Hi, we have debugged this issue. Please pull the new code (bab20b275b580e7fb01ac0b7bcdca53a42abae14).

ysm022 commented 4 years ago

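The server runs CUDA 10.0, and the installed PaddlePaddle version is paddlepaddle-gpu 1.8.1.post107.

In util/util.py, adding the with block produces the error in the title. Without it, running the original code fails with: AssertionError: We Only support save_dygraph in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative Mode

The modified code is as follows: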

def torch_weight_to_paddle_model(torch_weight_file, paddle_model):
    torch_weight = torch.load(torch_weight_file)
    with fluid.dygraph.guard():
        fluid.dygraph.save_dygraph(paddle_model.state_dict(), './pretrained/resnet18-torch')
        paddle_weight, _ = fluid.load_dygraph('./pretrained/resnet18-torch')
        for k, p in torch_weight.items():
            if k in paddle_weight:
                np_parm = torch_weight[k].detach().numpy()
                if np_parm.shape == paddle_weight[k].shape:
                    paddle_weight[k] = np_parm
                else:
                    print('torch parm {} dose not match paddle parm {}'.format(k, k))
            elif 'running_mean' in k:
                np_parm = torch_weight[k].detach().numpy()
                if np_parm.shape == paddle_weight[k[:-12]+'_mean'].shape:
                    paddle_weight[k[:-12]+'_mean'] = np_parm
                else:
                    print('torch parm {} dose not match paddle parm {}'.format(k, k[:-12]+'_mean'))
            elif 'running_var' in k:
                np_parm = torch_weight[k].detach().numpy()
                if np_parm.shape == paddle_weight[k[:-11] + '_variance'].shape:
                    paddle_weight[k[:-11] + '_variance'] = np_parm
                else:
                    print('torch parm {} dose not match paddle parm {}'.format(k, k[:-11] + '_variance'))
            else:
                print('torch parm {} not exist in paddle modle'.format(k))
        paddle_model.set_dict(paddle_weight)
        fluid.dygraph.save_dygraph(paddle_model.state_dict(), './pretrained/resnet18-torch')
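
For reference, the key renaming that the function above applies to the batch-norm statistics can be sketched in isolation. This is only an illustration of the slicing logic (k[:-12] + '_mean' and k[:-11] + '_variance'), not code from the repository: the helper name torch_key_to_paddle_key is hypothetical, and it assumes the statistic name is the final component of the key, so it uses endswith where the original tests 'running_mean' in k.

```python
def torch_key_to_paddle_key(torch_key):
    """Hypothetical helper: rename a PyTorch state_dict key the way the
    function above does for batch-norm running statistics."""
    if torch_key.endswith('running_mean'):
        # 'running_mean' is 12 characters, matching k[:-12] + '_mean'
        return torch_key[:-len('running_mean')] + '_mean'
    if torch_key.endswith('running_var'):
        # 'running_var' is 11 characters, matching k[:-11] + '_variance'
        return torch_key[:-len('running_var')] + '_variance'
    # All other keys are looked up under their original name
    return torch_key


if __name__ == '__main__':
    for key in ('conv1.weight', 'bn1.running_mean', 'layer1.0.bn1.running_var'):
        print(key, '->', torch_key_to_paddle_key(key))
```

Printing the mapping for a few keys before converting makes it easier to see which torch parameters would fall through to the "not exist in paddle modle" branch.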