eva-n27 / BERT-for-Chinese-Question-Answering


How do I fix the error where the model fails to load? #3

Closed: tomtang110 closed this issue 5 years ago

tomtang110 commented 5 years ago

[screenshot: 2019-02-09 18:50:11] I saw the original PyTorch author's explanation, which says to strip the "bert" prefix, but the load still fails after removing it. [screenshot: 2019-02-09 18:51:42] The code error is as follows: [screenshot: 2019-02-09 18:52:32]
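For reference, a minimal sketch (not this repo's code, and not the missing screenshots) of how such a prefix mismatch fails under PyTorch's strict loading: the checkpoint's keys carry a prefix that the target module's own state_dict does not, so load_state_dict reports both missing and unexpected keys.

import torch
from torch import nn

# A toy module whose state_dict keys are unprefixed ('dense.weight', 'dense.bias').
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(4, 4)

# A checkpoint saved from a wrapper model carries a 'bert.'-style prefix on every key.
checkpoint = {'bert.dense.weight': torch.zeros(4, 4), 'bert.dense.bias': torch.zeros(4)}

try:
    Toy().load_state_dict(checkpoint)  # strict=True by default
except RuntimeError as err:
    # PyTorch reports the unprefixed names as missing and the prefixed ones as unexpected.
    print(err)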

eva-n27 commented 5 years ago

Could you share the keys of the state dict loaded from your pretrained model file? For example, with code like the following:

import torch

# Load the checkpoint on CPU and print every parameter name it contains.
state_dict = torch.load(PATH_OF_YOUR_MODEL_FILE, map_location='cpu')
for key in state_dict.keys():
    print(key)
tomtang110 commented 5 years ago

Could you share the keys of the state dict loaded from your pretrained model file? For example, with code like the following:

import torch

# Load the checkpoint on CPU and print every parameter name it contains.
state_dict = torch.load(PATH_OF_YOUR_MODEL_FILE, map_location='cpu')
for key in state_dict.keys():
    print(key)

bert.embeddings.word_embeddings.weight
bert.embeddings.position_embeddings.weight
bert.embeddings.token_type_embeddings.weight
bert.embeddings.LayerNorm.weight
bert.embeddings.LayerNorm.bias
bert.encoder.layer.0.attention.self.query.weight
bert.encoder.layer.0.attention.self.query.bias
bert.encoder.layer.0.attention.self.key.weight
bert.encoder.layer.0.attention.self.key.bias
bert.encoder.layer.0.attention.self.value.weight
bert.encoder.layer.0.attention.self.value.bias
bert.encoder.layer.0.attention.output.dense.weight
bert.encoder.layer.0.attention.output.dense.bias
bert.encoder.layer.0.attention.output.LayerNorm.weight
bert.encoder.layer.0.attention.output.LayerNorm.bias
bert.encoder.layer.0.intermediate.dense.weight
bert.encoder.layer.0.intermediate.dense.bias
bert.encoder.layer.0.output.dense.weight
bert.encoder.layer.0.output.dense.bias
bert.encoder.layer.0.output.LayerNorm.weight
bert.encoder.layer.0.output.LayerNorm.bias
(bert.encoder.layer.1 through bert.encoder.layer.11 repeat the same 16-key pattern)
bert.pooler.dense.weight
bert.pooler.dense.bias
cls.predictions.bias
cls.predictions.transform.dense.weight
cls.predictions.transform.dense.bias
cls.predictions.transform.LayerNorm.weight
cls.predictions.transform.LayerNorm.bias
cls.predictions.decoder.weight
cls.seq_relationship.weight
cls.seq_relationship.bias

The printed keys are as above. If convenient, I'd like to add you on WeChat to discuss; my WeChat ID is tangxiaoche1.

eva-n27 commented 5 years ago

Two ways to fix this:

1. When you load the parameters like this:

model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))

use the approach at lines 962-966 of run_squad.py to read in BERT's parameters: strip the "bert." prefix (for example, rename bert.embeddings.word_embeddings.weight to embeddings.word_embeddings.weight), then discard every parameter whose name starts with "cls".
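A minimal sketch of method 1, assuming `model` exposes the encoder as model.bert and that args.init_checkpoint is the checkpoint path used above; the key names come from the dump earlier in this thread:

import collections

import torch

def load_bert_weights(model, checkpoint_path):
    # Read the full pretraining checkpoint on CPU.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    renamed = collections.OrderedDict()
    for key, value in state_dict.items():
        # Discard the pretraining heads ('cls.predictions.*', 'cls.seq_relationship.*').
        if key.startswith('cls.'):
            continue
        # Strip the 'bert.' prefix so the keys match model.bert's parameter names, e.g.
        # 'bert.embeddings.word_embeddings.weight' -> 'embeddings.word_embeddings.weight'.
        if key.startswith('bert.'):
            key = key[len('bert.'):]
        renamed[key] = value
    model.bert.load_state_dict(renamed)

# Usage, with `model` and `args` from the surrounding training script:
# load_bert_weights(model, args.init_checkpoint)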

2. When you load the parameters like this:

model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))

initialize the qa_output parameters in some way (for example, randomly), add them to the state_dict you build, and likewise discard every parameter whose name starts with "cls".
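A minimal sketch of method 2, assuming the QA head is a single Linear layer producing start/end logits. The attribute name is an assumption: the comment above calls it qa_output, while run_squad.py-style models use qa_outputs, so match the key names to your model:

import collections

import torch

def build_full_state_dict(checkpoint_path, num_labels=2):
    pretrained = torch.load(checkpoint_path, map_location='cpu')
    # Keep the 'bert.'-prefixed encoder weights as-is; drop the 'cls.*' pretraining heads.
    state_dict = collections.OrderedDict(
        (k, v) for k, v in pretrained.items() if not k.startswith('cls.')
    )
    # Randomly initialize the QA head under the names the model expects
    # (assumed 'qa_outputs.*' here; rename if your model uses 'qa_output').
    hidden_size = state_dict['bert.pooler.dense.weight'].size(0)
    state_dict['qa_outputs.weight'] = torch.empty(num_labels, hidden_size).normal_(mean=0.0, std=0.02)
    state_dict['qa_outputs.bias'] = torch.zeros(num_labels)
    return state_dict

# Usage: model.load_state_dict(build_full_state_dict(args.init_checkpoint))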

If you run into any other problems, feel free to open an issue or email xiang.zhengpeng@gmail.com.