facebookresearch / XLM

PyTorch original implementation of Cross-lingual Language Model Pretraining.

How to get perplexity of a sentence using a pretrained model? #272

Closed: thammegowda closed this issue 4 years ago.

thammegowda commented 4 years ago

For example:

Setup:

```python
import torch
xlmr = torch.hub.load('pytorch/fairseq', 'xlmr.large')
xlmr.eval()
```

Test:

```python
sent1 = "this is a good sentence"
sent2 = "*%&^( some *^(*&)(^&*(^% gibberish &^$^%$^&*&"

perp1 = xlmr.what?(sent1)
perp2 = xlmr.what?(sent2)

print(perp1, perp2)
```

Thanks,

thammegowda commented 4 years ago

I found what I was looking for. For anyone interested in this topic:

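```python
import torch

xlmr = torch.hub.load('pytorch/fairseq', 'xlmr.large')
xlmr.eval()  # inference mode: disables dropout

@torch.no_grad()  # no gradients are needed for scoring
def force_decode(sentence):
    # Scores every token at its own (unmasked) position and returns
    # exp(sum of token log-probs), i.e. a sentence-level probability.
    indices = xlmr.encode(sentence).view(1, -1)    # [B=1 x T] BPE ids incl. <s> ... </s>
    feats = xlmr.extract_features(indices)         # [B=1 x T x D] last-layer features
    scores = xlmr.model.output_layer(feats)        # [B=1 x T x V] vocabulary logits
    log_probs = torch.log_softmax(scores, dim=-1)  # [B=1 x T x V]
    tok_log_probs = log_probs.gather(dim=2, index=indices.unsqueeze(2))  # [B=1 x T x 1]
    log_prob = tok_log_probs.squeeze(2).sum(dim=1)  # [B=1 x T] -> [B=1]
    return torch.exp(log_prob).item()
```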
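Note that `force_decode` returns the probability of the whole token sequence, exp(Σ log p), rather than a perplexity. Perplexity normalizes by the number of tokens N and inverts: PPL = exp(-(1/N) Σ log p) = P^(-1/N), so lower is better and longer sentences are not penalized merely for their length. A minimal stdlib-only sketch of both quantities, assuming you already have the per-token natural-log probabilities as a Python list (the helper names `sentence_prob` and `perplexity` are illustrative, not part of fairseq):

```python
import math
from typing import Sequence


def sentence_prob(token_log_probs: Sequence[float]) -> float:
    """exp(sum of natural-log token probabilities); the quantity force_decode returns."""
    return math.exp(sum(token_log_probs))


def perplexity(token_log_probs: Sequence[float]) -> float:
    """exp(-(1/N) * sum log p): length-normalized and inverted, lower is better."""
    if not token_log_probs:
        raise ValueError("need at least one token log-probability")
    return math.exp(-sum(token_log_probs) / len(token_log_probs))
```

The per-token values to pass in are the gathered log-probabilities that `force_decode` sums before exponentiating (converted with `.squeeze().tolist()`). Given the near-1 probabilities reported in this thread, the resulting perplexities will also be close to 1.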

```python
sents = ["this is a good sentence",
         "*%&^( some *^(*&)(^&*(^% gibberish &^$^%$^&*&",
         "c'est une phrase française",
         "c'est une phrase (^&*(^% gibberish &^$^% française",
         "ಇದು ಕನ್ನಡ ವಾಕ್ಯ",
         "ಇದು ಕನ್ನಡ (^&*(^% gibberish &^$^% ವಾಕ್ಯ",
         "यह एक हिंदी वाक्य है",
         "यह एक हिंदी (^&*(^% gibberish &^$^% वाक्य है"]
for sent in sents:
    print('%.6f' % force_decode(sent), sent)
```

Output:

```
0.996101 this is a good sentence
0.000000 *%&^( some *^(*&)(^&*(^% gibberish &^$^%$^&*&
0.998722 c'est une phrase française
0.929215 c'est une phrase (^&*(^% gibberish &^$^% française
0.999590 ಇದು ಕನ್ನಡ ವಾಕ್ಯ
0.914713 ಇದು ಕನ್ನಡ (^&*(^% gibberish &^$^% ವಾಕ್ಯ
0.999829 यह एक हिंदी वाक्य है
0.843531 यह एक हिंदी (^&*(^% gibberish &^$^% वाक्य है
```
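The near-1 scores for well-formed sentences are likely inflated: XLM-R is a masked language model, and `force_decode` feeds the sentence unmasked, so each token is visible at the very position being scored. A commonly used alternative for masked LMs is pseudo-perplexity (pseudo-log-likelihood scoring, e.g. Salazar et al., 2020, "Masked Language Model Scoring"): replace one token at a time with `<mask>`, score the original token at that position, and take exp of the negative mean. The sketch below is written against a hypothetical callback `masked_log_prob(masked_ids, position, target_id)` so the core loop needs only the standard library. `make_xlmr_scorer` is a hypothetical adapter that assumes `xlmr.model(tokens)` returns `(logits, extra)`, as fairseq's `RobertaHubInterface.fill_mask` relies on, and `skip_special=1` assumes `xlmr.encode` adds exactly one `<s>` and one `</s>`.

```python
import math
from typing import Callable, List, Sequence

# Hypothetical callback: log p(target_id) at `position`, given the masked id sequence.
MaskedLogProb = Callable[[List[int], int, int], float]


def pseudo_perplexity(token_ids: Sequence[int], mask_id: int,
                      masked_log_prob: MaskedLogProb,
                      skip_special: int = 1) -> float:
    """Mask each scorable position in turn and return exp(-mean log p(original token))."""
    # Assumption: `skip_special` special tokens at each end (<s> ... </s>) are not scored.
    positions = range(skip_special, len(token_ids) - skip_special)
    if len(positions) == 0:
        raise ValueError("no scorable tokens")
    total = 0.0
    for i in positions:
        masked = list(token_ids)
        masked[i] = mask_id  # hide only the token being scored
        total += masked_log_prob(masked, i, token_ids[i])
    return math.exp(-total / len(positions))


def make_xlmr_scorer(xlmr) -> MaskedLogProb:
    """Hypothetical adapter for a fairseq XLM-R hub model (call xlmr.eval() first)."""
    import torch  # imported lazily so pseudo_perplexity itself needs only the stdlib

    @torch.no_grad()
    def masked_log_prob(masked_ids: List[int], position: int, target_id: int) -> float:
        tokens = torch.tensor([masked_ids], dtype=torch.long)  # [1 x T]
        logits, _ = xlmr.model(tokens)  # assumed [1 x T x V], mirroring fill_mask
        return torch.log_softmax(logits[0, position], dim=-1)[target_id].item()

    return masked_log_prob
```

Usage would look like `ids = xlmr.encode(sent).tolist()` followed by `pseudo_perplexity(ids, xlmr.task.source_dictionary.index('<mask>'), make_xlmr_scorer(xlmr))`. This needs one forward pass per scored token, so it is roughly T times slower than `force_decode`; stacking the T masked copies into a single batch is the usual way to speed it up.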