MuQiuJun-AI / bert4pytorch

超轻量级bert的pytorch版本,大量中文注释,容易修改结构,持续更新
408 stars 68 forks source link

返回 embedding 和 huggingface 的返回结果不完全一致 #8

Open mmmwhy opened 2 years ago

mmmwhy commented 2 years ago

比如 bert-base-chinese,作者是否有做过这方面的评估测试呀~

mmmwhy commented 2 years ago

sentence = "我是一个好男人!" max_len = 32 已设置 .eval

huggingface 结果

image

bert4pytorch 结果

image

mmmwhy commented 2 years ago

原始版本

from transformers import BertModel
from transformers import BertTokenizer

sentence = "我是一个好男人!"
max_len = 32

bert_model = BertModel.from_pretrained("/bert-base-chinese")
bert_model.eval()

text_tokenizer = BertTokenizer.from_pretrained("/bert-base-chinese", do_lower_case=True)
tensor_caption = text_tokenizer.encode(sentence, 
                return_tensors="pt",
                padding='max_length',
                truncation=True,max_length=max_len)

pooler_output = bert_model(tensor_caption).pooler_output
last_hidden_state = bert_model(tensor_caption).last_hidden_state

bert4pytorch 版本

import torch
from bert4pytorch.modeling import build_transformer_model
from bert4pytorch.tokenization import Tokenizer

sentence = "我是一个好男人!"
max_len = 32

root_model_path = "/bert-base-chinese"
vocab_path = root_model_path + "/vocab.txt"
config_path = root_model_path + "/config.json"
checkpoint_path = root_model_path + '/pytorch_model.bin'

# 建立分词器
tokenizer = Tokenizer(vocab_path)

# 读取数据
tokens_ids, segments_ids = tokenizer.encode(sentence, max_len=max_len)
tokens_ids = tokens_ids + (max_len - len(tokens_ids)) * [0]
segments_ids = segments_ids + (max_len - len(segments_ids)) * [0]
tokens_ids_tensor = torch.tensor([tokens_ids])
segment_ids_tensor = torch.tensor([segments_ids])

model = build_transformer_model(config_path, checkpoint_path, with_pool=True)
model.eval()

encoded_layers, pooled_output = model(tokens_ids_tensor, segment_ids_tensor)
Tongjilibo commented 2 years ago

试过把transformer中max_length这个入参去掉,两者是一致的

DimariaW commented 2 years ago

经过我的调试,这个问题最终定位是hugging face 的模型中对layerNorm参数的命名是"gamma"和“beta”。 但是作者导入参数时写的mapping是weight和bias,因此参数导入失败