Could you share the keys of the state dict read from the pretrained model file? For example, with the following code:
state_dict = torch.load(PATH_OF_YOUR_MODEL_FILE, map_location='cpu')
for key in state_dict.keys():
    print(key)
bert.embeddings.word_embeddings.weight
bert.embeddings.position_embeddings.weight
bert.embeddings.token_type_embeddings.weight
bert.embeddings.LayerNorm.weight
bert.embeddings.LayerNorm.bias
bert.encoder.layer.0.attention.self.query.weight
bert.encoder.layer.0.attention.self.query.bias
bert.encoder.layer.0.attention.self.key.weight
bert.encoder.layer.0.attention.self.key.bias
bert.encoder.layer.0.attention.self.value.weight
bert.encoder.layer.0.attention.self.value.bias
bert.encoder.layer.0.attention.output.dense.weight
bert.encoder.layer.0.attention.output.dense.bias
bert.encoder.layer.0.attention.output.LayerNorm.weight
bert.encoder.layer.0.attention.output.LayerNorm.bias
bert.encoder.layer.0.intermediate.dense.weight
bert.encoder.layer.0.intermediate.dense.bias
bert.encoder.layer.0.output.dense.weight
bert.encoder.layer.0.output.dense.bias
bert.encoder.layer.0.output.LayerNorm.weight
bert.encoder.layer.0.output.LayerNorm.bias
(the same 16 keys repeat for bert.encoder.layer.1 through bert.encoder.layer.11)
bert.pooler.dense.weight
bert.pooler.dense.bias
cls.predictions.bias
cls.predictions.transform.dense.weight
cls.predictions.transform.dense.bias
cls.predictions.transform.LayerNorm.weight
cls.predictions.transform.LayerNorm.bias
cls.predictions.decoder.weight
cls.seq_relationship.weight
cls.seq_relationship.bias
The printed keys are as above. If convenient, I would like to connect on WeChat to discuss; my WeChat ID is tangxiaoche1.
There are two solutions:
1. When you read the parameters like this:
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
use the approach at lines 962-966 of run_squad.py to read in the BERT parameters: strip the "bert." prefix so that, for example, bert.embeddings.word_embeddings.weight becomes embeddings.word_embeddings.weight, and discard the parameters whose names begin with "cls" (see the first sketch after this list).
2. When you load them like this:
model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
initialize the qa_output parameters in some way (e.g., randomly), put them into the state dict you build, and likewise discard the parameters whose names begin with "cls" (see the second sketch after this list).
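A minimal sketch of the first solution, assuming `args.init_checkpoint` and a `model` with a `.bert` submodule as in the answer above; the exact remapping done at lines 962-966 of run_squad.py may differ slightly:

import torch

# Load the full pretraining checkpoint on CPU.
state_dict = torch.load(args.init_checkpoint, map_location='cpu')

# Keep only the encoder weights, stripping the "bert." prefix so the
# keys match model.bert's own parameter names; everything that starts
# with "cls" (the pretraining heads) is dropped by the filter.
bert_state_dict = {
    key[len('bert.'):]: value
    for key, value in state_dict.items()
    if key.startswith('bert.')
}

model.bert.load_state_dict(bert_state_dict)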
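And a minimal sketch of the second solution under the same assumptions; here the name qa_output follows the answer above and may be spelled differently in your model:

import torch

pretrained = torch.load(args.init_checkpoint, map_location='cpu')

# Start from the freshly constructed model's own state dict, so the
# qa_output head keeps its (e.g., random) initialization, then copy in
# every pretrained tensor except the "cls.*" heads, which the
# fine-tuning model does not have.
state_dict = model.state_dict()
for key, value in pretrained.items():
    if not key.startswith('cls.'):
        state_dict[key] = value

model.load_state_dict(state_dict)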
If you have any further questions, feel free to open an issue or email xiang.zhengpeng@gmail.com.
I read the original PyTorch author's explanation, which is to remove the "bert" prefix, but after removing it I still get an error. The code error is as follows: