DLLXW / baby-llama2-chinese

A repository for pretraining from scratch + SFT of a small-parameter Chinese LLaMa2; a single 24 GB GPU is enough to train a chat-llama2 with basic Chinese Q&A ability.

Why does model.compile at line 309 of pretrain add the prefix '_orig_mod'? #34

Closed ToxicNeil closed 1 year ago

ToxicNeil commented 1 year ago

Running sft.py raises an error: the state_dict keys cannot be matched:

Initializing a new model from scratch
Traceback (most recent call last):
  File "sft.py", line 295, in <module>
    model.load_state_dict(torch.load('./out/20230915_baike_pretrain/epoch_0.pth'))
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Transformer:
        Missing key(s) in state_dict: "tok_embeddings.weight", "layers.0.attention.wq.weight", "layers.0.attention.wk.weight", "layers.0.attention.wv.weight"...

After the model is initialized at line 292 of sft.py, I used the code from the resume path to strip the prefix, and the error went away:

model = init_model()
pretrain_state_dict = torch.load('./out/baike_pretrain/epoch_0.pth')
# The pretrain checkpoint was saved from a torch.compile()-wrapped model,
# so its keys carry an "_orig_mod." prefix; strip it so the keys match the
# freshly initialized (uncompiled) model.
unwanted_prefix = "_orig_mod."
for k, v in list(pretrain_state_dict.items()):
    if k.startswith(unwanted_prefix):
        pretrain_state_dict[k[len(unwanted_prefix):]] = pretrain_state_dict.pop(k)
model.load_state_dict(pretrain_state_dict)
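
For context, here is a minimal sketch (assuming PyTorch 2.x; the tiny nn.Linear model and file name below are illustrative, not from this repo) of where the prefix comes from: torch.compile() returns an OptimizedModule wrapper, and the wrapper's state_dict keys are prefixed with "_orig_mod.". If pretrain saves the compiled model's state_dict, an uncompiled model in sft.py cannot load it without stripping that prefix, or without saving the unwrapped module's state_dict in the first place.

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
compiled = torch.compile(model)  # wraps model in an OptimizedModule

print(list(model.state_dict().keys()))     # ['weight', 'bias']
print(list(compiled.state_dict().keys()))  # typically ['_orig_mod.weight', '_orig_mod.bias']

# Option 1: save the underlying module's state_dict so the keys stay unprefixed
# (_orig_mod is the wrapper's reference to the original module).
torch.save(compiled._orig_mod.state_dict(), 'epoch_0.pth')

# Option 2: keep saving the wrapper's state_dict and strip the "_orig_mod."
# prefix at load time, as in the snippet above.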
ToxicNeil commented 1 year ago

Already resolved; I am just raising the question out of curiosity.