BramVanroy / bert-for-inference

A small repo showing how to easily use BERT (or other transformers) for inference

dim should be 0 instead of 1? #2

Open monk1337 opened 4 years ago

monk1337 commented 4 years ago

I am running your code and found an issue:

sentence_embedding = torch.mean(hidden_states[-1], dim=1).squeeze()

should be

sentence_embedding = torch.mean(hidden_states[-1], dim=0).squeeze()

Please check!

BramVanroy commented 4 years ago

I haven't used this notebook in forever, so I am not sure how it works with newer transformers versions. Running on 2.2.2, you get the following sizes for the resulting embeddings:

This is what you want: one vector for the whole sentence.

sentence_embedding = torch.mean(hidden_states[-1], dim=1).squeeze()
print(sentence_embedding.size())
# torch.Size([768])

What you suggest:

sentence_embedding = torch.mean(hidden_states[-1], dim=0).squeeze()
print(sentence_embedding.size())
# torch.Size([5, 768])

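Taking the mean over dim 0 does not make sense here. Dim 0 is the batch dimension, which holds only one item, so averaging over it changes nothing: sentence_embedding is just hidden_states[-1] with its batch dimension (of size 1) removed, i.e. still one vector per token rather than one vector for the sentence.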
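To make the difference concrete, here is a minimal sketch with a random tensor standing in for hidden_states[-1]. No model is loaded; the shape [1, 5, 768] (batch size 1, 5 tokens, hidden size 768) is assumed from the sizes printed above.

```python
import torch

# Random stand-in for hidden_states[-1], with the assumed shape
# [batch_size=1, seq_len=5, hidden_size=768] from the 2.2.2 example above.
last_layer = torch.randn(1, 5, 768)

# dim=1 averages over the 5 token positions -> one 768-d vector for the sentence.
sentence_embedding = torch.mean(last_layer, dim=1).squeeze()
print(sentence_embedding.size())  # torch.Size([768])

# dim=0 averages over a batch of one item, which leaves the values unchanged;
# after squeeze() the result is still one vector per token.
per_token = torch.mean(last_layer, dim=0).squeeze()
print(per_token.size())  # torch.Size([5, 768])
print(torch.allclose(per_token, last_layer[0]))  # True
```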

monk1337 commented 4 years ago

I am using bert-large with transformers 3.0.0:

import torch
from transformers import BertModel, BertConfig, BertTokenizer
sentence = 'checking the bert encoding with this sentence'

tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
model     = BertModel.from_pretrained('bert-large-uncased')

input_ids = torch.tensor(tokenizer.encode(sentence)).unsqueeze(0)  # Batch size 1
outputs   = model(input_ids)

If I print outputs[1], as you do in your code:

print(outputs[1].shape)
# torch.Size([1, 1024])

But if I print outputs[0], it gives me the hidden states:

print(outputs[0].shape)
# torch.Size([1, 9, 1024])

Now, if I apply your code to this output:

sentence_embedding = torch.mean(outputs[0][-1], dim=1).squeeze()
print(sentence_embedding.size())
# torch.Size([9])

If I change it to dim=0, then:

sentence_embedding = torch.mean(outputs[0][-1], dim=0).squeeze()
print(sentence_embedding.size())
# torch.Size([1024])
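A sketch of what each indexing step selects, using a random tensor with the shape printed for outputs[0] ([1, 9, 1024]: batch, tokens, hidden size). It assumes, as in transformers 3.x, that outputs[0] is the last layer's hidden state for every token rather than a tuple of all layers, so [-1] picks the last (here, the only) item in the batch, not the last layer:

```python
import torch

# Random stand-in with the shape printed for outputs[0]:
# [batch_size=1, seq_len=9, hidden_size=1024].
last_hidden_state = torch.randn(1, 9, 1024)

# [-1] selects the last item along dim 0 (the batch), leaving
# a [seq_len, hidden_size] tensor for that one sentence.
single_sentence = last_hidden_state[-1]
print(single_sentence.size())  # torch.Size([9, 1024])

# On this 2-D tensor, dim=1 averages over hidden features (one number per token),
# while dim=0 averages over tokens (one vector for the sentence).
print(torch.mean(single_sentence, dim=1).size())  # torch.Size([9])
print(torch.mean(single_sentence, dim=0).size())  # torch.Size([1024])

# Averaging the original 3-D tensor over dim=1 gives the same sentence vector.
sentence_embedding = torch.mean(last_hidden_state, dim=1).squeeze()
print(torch.allclose(sentence_embedding, torch.mean(single_sentence, dim=0), atol=1e-6))  # True
```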

Am I making a mistake here? Please correct me!

Also, you suggested another approach: take the outputs of the last four hidden layers and average them. How should I do that with this output?

Thank you!
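One way to average the last four layers is sketched below. It assumes the model was loaded with output_hidden_states=True (e.g. BertModel.from_pretrained('bert-large-uncased', output_hidden_states=True)), in which case the per-layer hidden states are assumed to come back as an extra element of outputs (outputs[2] in the 3.x tuple): a tuple holding the embedding output plus one tensor per layer, 25 tensors for the 24-layer bert-large, each shaped [batch, tokens, hidden]. Random tensors stand in for that tuple, and the helper name mean_of_last_layers is hypothetical.

```python
import torch

def mean_of_last_layers(hidden_states, num_layers=4):
    """Hypothetical helper: average the last `num_layers` layers, then the tokens.

    `hidden_states` is assumed to be a tuple of tensors, each shaped
    [batch_size, seq_len, hidden_size] (embedding output first, last layer last).
    Returns a [batch_size, hidden_size] tensor: one vector per sentence.
    """
    # Stack the selected layers into [num_layers, batch_size, seq_len, hidden_size].
    last_layers = torch.stack(hidden_states[-num_layers:], dim=0)
    # Average across layers -> [batch_size, seq_len, hidden_size].
    token_vectors = torch.mean(last_layers, dim=0)
    # Average across tokens -> [batch_size, hidden_size].
    return torch.mean(token_vectors, dim=1)

# Assumed stand-in for outputs[2] from bert-large-uncased with output_hidden_states=True:
# 25 tensors (embeddings + 24 layers) for one sentence of 9 tokens.
hidden_states = tuple(torch.randn(1, 9, 1024) for _ in range(25))

sentence_embedding = mean_of_last_layers(hidden_states).squeeze()
print(sentence_embedding.size())  # torch.Size([1024])
```

Since both steps are plain means over a batch of one sentence, averaging layers first and tokens second gives the same vector as the reverse order.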