xiangking / ark-nlp

A private NLP coding package that quickly implements SOTA solutions.
Apache License 2.0

BERT model parameter dimensions do not match #74

Closed: 943685519 closed this issue 2 years ago

943685519 commented 2 years ago

I trained a model with ark_nlp's `from ark_nlp.model.tc.bert import Bert` and saved it. When I then tried to load the checkpoint into transformers' BertModel via load_state_dict, the parameter dimensions did not match, with this error: size mismatch for pooler.dense.bias: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([768]).

PS: I also noticed that the state dict of a model trained with ark_nlp's Bert differs from transformers' BertModel by two parameters: "classifier.weight" and "classifier.bias".
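For reference, the failing load looks roughly like this; the checkpoint path and base model name below are placeholders, and the ark-nlp training step is omitted:

```python
import torch
from transformers import BertModel

# "ark_nlp_bert.pth" is a hypothetical path to the saved ark-nlp checkpoint.
model = BertModel.from_pretrained("bert-base-chinese")
state_dict = torch.load("ark_nlp_bert.pth", map_location="cpu")

# Fails with size-mismatch / unexpected-key errors, as reported above.
model.load_state_dict(state_dict)
```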

xiangking commented 2 years ago

Hello, the Bert in ark-nlp is not equivalent to BertModel. Our Bert is actually BertModel plus a classification layer, so it cannot be loaded with BertModel alone.
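If only the encoder weights are needed, one workaround is to drop the classification-head entries and load non-strictly. This is a minimal sketch, assuming the checkpoint keys mirror BertModel's plus the extra classifier head described above; the path and base model name are placeholders:

```python
import torch
from transformers import BertModel

state_dict = torch.load("ark_nlp_bert.pth", map_location="cpu")  # placeholder path

# Keep only the keys belonging to the plain encoder, dropping the
# classification head that ark-nlp's Bert adds on top of BertModel.
encoder_state = {k: v for k, v in state_dict.items() if not k.startswith("classifier.")}

model = BertModel.from_pretrained("bert-base-chinese")
missing, unexpected = model.load_state_dict(encoder_state, strict=False)
print("missing:", missing)        # should be empty if the key layouts match
print("unexpected:", unexpected)  # any remaining extra keys from the checkpoint
```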

943685519 commented 2 years ago

> Hello, the Bert in ark-nlp is not equivalent to BertModel. Our Bert is actually BertModel plus a classification layer, so it cannot be loaded with BertModel alone.

Got it, I noticed that too. Loading with transformers' BertForSequenceClassification works.
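A sketch of that load; num_labels, the base model name, and the checkpoint path are assumptions for illustration and should match the actual training setup:

```python
import torch
from transformers import BertConfig, BertForSequenceClassification

# num_labels=4 is an assumption (set it to the number of classes used in
# training); "bert-base-chinese" and the path are placeholders.
config = BertConfig.from_pretrained("bert-base-chinese", num_labels=4)
model = BertForSequenceClassification(config)

state_dict = torch.load("ark_nlp_bert.pth", map_location="cpu")
model.load_state_dict(state_dict)  # the classifier head now has matching slots
model.eval()
```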

xiangking commented 2 years ago

OK, I'll close this issue then. Feel free to open another issue if you run into further questions.