Closed thammegowda closed 4 years ago
I got what I was looking for. To anyone interested in this topic:
import torch
# Load the 'xlmr.large' entry point from the 'pytorch/fairseq' hub repository.
xlmr = torch.hub.load('pytorch/fairseq', 'xlmr.large')
# Switch the loaded model to evaluation mode before it is used for scoring.
xlmr.eval()
def force_decode(sentence):
    """Score *sentence* by the probability XLM-R assigns to its own tokens.

    The sentence is encoded with ``xlmr.encode``, passed through the model,
    and at every position the log-probability that the output layer gives
    to the token actually present at that position is gathered. Those
    per-position log-probabilities are summed and exponentiated.

    Args:
        sentence: Text to score; it is handed to ``xlmr.encode`` as-is.

    Returns:
        A Python float equal to the product, over all positions, of the
        softmax probability of the observed token. Because it is a product
        of probabilities it shrinks quickly as more low-probability tokens
        are present.

    NOTE(review): the input tokens are fed in unmasked, so this looks like
    a heuristic "does the model reproduce its input" score rather than a
    proper language-model likelihood -- confirm before relying on it as one.
    """
    # Pure inference: skip autograd bookkeeping so no graph is built and
    # retained for a value that is immediately converted to a float.
    with torch.no_grad():
        indices = xlmr.encode(sentence).view(1, -1)  # [B=1 x T]
        feats = xlmr.extract_features(indices)  # [B=1 x T x D]
        scores = xlmr.model.output_layer(feats)  # [B=1 x T x V]
        log_probs = torch.log_softmax(scores, dim=-1)  # [B=1 x T x V]
        # Pick, at each position, the log-prob of the token that is there.
        token_log_probs = log_probs.gather(dim=2, index=indices.unsqueeze(2))  # [B x T x 1]
        log_prob = token_log_probs.squeeze(2).sum(dim=1)  # [B=1 x T] -> [B=1]
        return torch.exp(log_prob).item()
# Demo inputs: a clean sentence in each language, followed by a variant
# with symbol noise mixed in.
sents = [
    "this is a good sentence",
    "*%&^( some *^(*&)(^&*(^% gibberish &^$^%$^&*&",
    "c'est une phrase française",
    "c'est une phrase (^&*(^% gibberish &^$^% française",
    "ಇದು ಕನ್ನಡ ವಾಕ್ಯ",
    "ಇದು ಕನ್ನಡ (^&*(^% gibberish &^$^% ವಾಕ್ಯ",
    "यह एक हिंदी वाक्य है",
    "यह एक हिंदी (^&*(^% gibberish &^$^% वाक्य है",
]

# Print one line per sentence: the score to six decimals, then the text.
for text in sents:
    score = force_decode(text)
    print('%.6f' % score, text)
0.996101 this is a good sentence
0.000000 *%&^( some *^(*&)(^&*(^% gibberish &^$^%$^&*&
0.998722 c'est une phrase française
0.929215 c'est une phrase (^&*(^% gibberish &^$^% française
0.999590 ಇದು ಕನ್ನಡ ವಾಕ್ಯ
0.914713 ಇದು ಕನ್ನಡ (^&*(^% gibberish &^$^% ವಾಕ್ಯ
0.999829 यह एक हिंदी वाक्य है
0.843531 यह एक हिंदी (^&*(^% gibberish &^$^% वाक्य है
For example:
setup
Test
Thanks,