thunlp / OpenMatch

An Open-Source Package for Information Retrieval.
MIT License
448 stars 42 forks

inference.py raises an error #9

Closed smart022 closed 3 years ago

smart022 commented 3 years ago

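Hello, I am trying to reproduce the first bert-base model for MS MARCO Passage Ranking. Following the parameters in the provided sh script, I downloaded the required checkpoints/bert-base.bin, but loading the parameters from bert-base.bin seems to fail. The details are below; the error reports that _model.embeddings.position_ids is missing.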

reading test data...
Traceback (most recent call last):
  File "inference.py", line 201, in <module>
    main()
  File "inference.py", line 187, in main
    model.load_state_dict(st)
  File "/home/xxx/anaconda3/envs/tmp_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1044, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Bert:
    Missing key(s) in state_dict: "_model.embeddings.position_ids". 

I read the source of inference.py and traced the error to the following code:

    state_dict = torch.load(args.checkpoint)
    if args.model == 'bert':
        st = {}
        for k in state_dict:
            if k.startswith('bert'):
                st['_model'+k[len('bert'):]] = state_dict[k]
            elif k.startswith('classifier'):
                st['_dense'+k[len('classifier'):]] = state_dict[k]
            else:
                st[k] = state_dict[k]
        model.load_state_dict(st)  # the error is raised here
    else:
        model.load_state_dict(state_dict)
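
The key renaming in the snippet above can be pulled out into a small function so that the resulting keys can be inspected before loading. This is only a minimal sketch: the name `remap_checkpoint_keys` is hypothetical and not part of OpenMatch, and plain strings stand in for tensors.

```python
def remap_checkpoint_keys(state_dict):
    """Rename checkpoint keys the same way the inference.py snippet does.

    'bert...'       -> '_model...'
    'classifier...' -> '_dense...'
    anything else is copied unchanged.
    """
    remapped = {}
    for key, value in state_dict.items():
        if key.startswith('bert'):
            remapped['_model' + key[len('bert'):]] = value
        elif key.startswith('classifier'):
            remapped['_dense' + key[len('classifier'):]] = value
        else:
            remapped[key] = value
    return remapped


# Placeholder values instead of tensors, for illustration only.
example_checkpoint = {
    'bert.embeddings.word_embeddings.weight': 'w0',
    'classifier.weight': 'w1',
    'classifier.bias': 'b1',
}
example_remapped = remap_checkpoint_keys(example_checkpoint)
print(sorted(example_remapped))
```

Comparing `sorted(example_remapped)` with the keys of `model.state_dict()` shows which names the checkpoint does not provide.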

After commenting out the code above, inference.py runs normally, but the model it uses is then the one without fine-tuning. How can I get this bert-base.bin to load?

smart022 commented 3 years ago

OK, I solved it myself. See "Bert Checkpoint Breaks 3.02 -> 3.1.0 due to new buffer in BertEmbeddings"; it appears to be a Transformers version issue. The fix is to pass strict=False when calling model.load_state_dict().
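
A minimal sketch of the strict=False workaround, assuming PyTorch is installed. `TinyEmbeddings` and `load_allowing_missing` are hypothetical names used for illustration; they are not part of OpenMatch or Transformers. The helper loads non-strictly but still fails if anything other than the explicitly allowed keys is missing, so a genuinely mismatched checkpoint is not silently accepted.

```python
import torch
from torch import nn


class TinyEmbeddings(nn.Module):
    """Hypothetical stand-in for a module that registers a position_ids buffer."""

    def __init__(self, max_positions=8, hidden=4):
        super().__init__()
        self.position_embeddings = nn.Embedding(max_positions, hidden)
        # A registered buffer appears as a key in state_dict(), so a
        # checkpoint saved without it fails a strict load.
        self.register_buffer(
            "position_ids", torch.arange(max_positions).unsqueeze(0)
        )


def load_allowing_missing(model, state_dict, allowed_missing):
    """Load with strict=False, but reject any mismatch that was not expected."""
    result = model.load_state_dict(state_dict, strict=False)
    not_allowed = [k for k in result.missing_keys if k not in allowed_missing]
    if not_allowed or result.unexpected_keys:
        raise RuntimeError(
            f"missing: {not_allowed}, unexpected: {list(result.unexpected_keys)}"
        )
    return result


# Simulate an older checkpoint that lacks the position_ids buffer.
old_checkpoint = {
    k: v for k, v in TinyEmbeddings().state_dict().items() if k != "position_ids"
}

model = TinyEmbeddings()
try:
    model.load_state_dict(old_checkpoint)  # strict=True is the default
    strict_failed = False
except RuntimeError:
    strict_failed = True

result = load_allowing_missing(model, old_checkpoint, {"position_ids"})
print(strict_failed, result.missing_keys)
```

In the inference.py case the allowed key would be `_model.embeddings.position_ids`, the one named in the traceback. Since that buffer is recreated when the model is constructed, skipping it leaves the fine-tuned weights from the checkpoint intact.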