allenai / kb

KnowBert -- Knowledge Enhanced Contextual Word Representations

Use KnowBert to predict missing words #11

Closed by jzbjyb 4 years ago

jzbjyb commented 4 years ago

Hi Matthew,

Thanks a bunch for the documentation on embedding sentences programmatically. It saved me a lot of time! I made a small modification so that I can use KnowBert to predict a missing word (i.e., [MASK]) in a sentence, but the results are unexpected. I am not sure whether my implementation is correct; here is the code snippet:

from kb.include_all import ModelArchiveFromParams
from kb.knowbert_utils import KnowBertBatchifier
from allennlp.common import Params
import torch
import torch.nn.functional as F

archive_file = 'https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/knowbert_wiki_wordnet_model.tar.gz'

# load model and batcher
params = Params({'archive_file': archive_file})
model = ModelArchiveFromParams.from_params(params=params)
model.eval()
batcher = KnowBertBatchifier(archive_file)

# get bert vocab
vocab = list(batcher.tokenizer_and_candidate_generator.bert_tokenizer.ids_to_tokens.values())

sentences = ['Paris is located in [MASK].']
mask_ind = 5  # index of [MASK] in the wordpiece sequence ([CLS] is index 0)

for batch in batcher.iter_batches(sentences, verbose=False):
    model_output = model(**batch)
    # the tokenized sentence, where the 6-th token is [MASK]
    print([vocab[w] for w in batch['tokens']['tokens'][0].numpy()])
    # the masked LM head maps the contextual embeddings to vocabulary logits
    logits, _ = model.pretraining_heads(model_output['contextual_embeddings'], model_output['pooled_output'])
    log_probs = F.log_softmax(logits, dim=-1)
    topk = torch.topk(log_probs[0, mask_ind], 10, 0)[1]
    # print the top 10 predictions
    print([vocab[t.item()] for t in topk])

The top 10 predictions are [UNK], the, itself, its, and, marne, to, them, first, lissa, while the top 10 predictions of BERT-base-uncased are france, paris, europe, italy, belgium, algeria, germany, russia, haiti, canada, which seems a bit weird. Is my implementation correct, or do you have any suggestions? Thanks in advance!
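
For reference, here is a minimal sketch of how the BERT-base-uncased baseline above could be reproduced; it assumes the HuggingFace transformers masked-LM API and is not part of this repo:

import torch
from transformers import BertForMaskedLM, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert = BertForMaskedLM.from_pretrained('bert-base-uncased')
bert.eval()

inputs = tokenizer('Paris is located in [MASK].', return_tensors='pt')
# index of the [MASK] token in the wordpiece sequence
mask_ind = (inputs['input_ids'][0] == tokenizer.mask_token_id).nonzero().item()

with torch.no_grad():
    logits = bert(**inputs).logits

# print the top 10 predictions for the masked position
topk = torch.topk(logits[0, mask_ind], 10).indices
print(tokenizer.convert_ids_to_tokens(topk.tolist()))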

matt-peters commented 4 years ago

If you want to fill in [MASK] tokens then it's necessary to initialize batcher = KnowBertBatchifier(archive_file, masking_strategy='full_mask'). This creates batches in the same way as during pretraining. After doing so, I get ['france', 'germany', 'belgium', 'europe', 'canada', 'italy', 'paris', 'spain', 'russia', 'algeria'] as the top 10 predictions for 'Paris is located in [MASK].'
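
Only the batcher construction changes; a minimal sketch of the corrected setup, reusing model, vocab, sentences, and mask_ind from the snippet above:

# masking_strategy='full_mask' batches [MASK] tokens the same way as during pretraining
batcher = KnowBertBatchifier(archive_file, masking_strategy='full_mask')

for batch in batcher.iter_batches(sentences, verbose=False):
    with torch.no_grad():
        model_output = model(**batch)
        logits, _ = model.pretraining_heads(model_output['contextual_embeddings'], model_output['pooled_output'])
    log_probs = F.log_softmax(logits, dim=-1)
    # print the top 10 predictions for the [MASK] position
    topk = torch.topk(log_probs[0, mask_ind], 10, 0)[1]
    print([vocab[t.item()] for t in topk])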

jzbjyb commented 4 years ago

Thanks for your quick reply! I had missed that full_mask is not the default; using it gives the correct predictions!