facebookresearch / XLM

PyTorch original implementation of Cross-lingual Language Model Pretraining.

Predict a masked word #361

Open · ebilal opened this issue 11 months ago

ebilal commented 11 months ago

How can I predict a masked word? The code below doesn't work, maybe because of `XLMTokenizer`.

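```python
import torch
from transformers import XLMTokenizer

from src.utils import AttrDict
from src.data.dictionary import Dictionary
from src.model.transformer import TransformerModel

# Reload the checkpoint saved during training
reloaded = torch.load('dumped/xlm_en/lavlwh2d6j/best-valid_en_mlm_ppl.pth')
model_params = AttrDict(reloaded['params'])
dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])

# Rebuild the encoder with its output (prediction) layer and load the trained weights
encoder = TransformerModel(model_params, dico, is_encoder=True, with_output=True).cuda().eval()
encoder.load_state_dict(reloaded['model'])

# Tokenize with the Hugging Face XLM tokenizer
tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
text = "The capital of France is <mask>."
tokens = tokenizer(text, return_tensors='pt')

with torch.no_grad():
    predictions, _ = encoder(**tokens, mode='predict')
```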
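Two mismatches seem likely here. First, the Hugging Face `XLMTokenizer` for `xlm-mlm-en-2048` uses that model's BPE vocabulary. Unless this checkpoint was trained with exactly the same BPE codes, that vocabulary will not line up with the dictionary stored in the checkpoint (`reloaded['dico_word2id']`), so the IDs point at the wrong words. Second, the tokenizer returns `(batch, length)` tensors named `input_ids`, `attention_mask`, and so on. This repo's `TransformerModel` instead takes a mode string plus keyword arguments: `'fwd'` expects `x` shaped `(slen, bs)` along with `lengths`, `langs` and `causal`, and `'predict'` expects the hidden states returned by `'fwd'` together with `pred_mask`, `y` and `get_scores`.

The following is a minimal sketch under these assumptions:

- the sentence has already been preprocessed and BPE-split exactly like the training data (same tokenization, lowercasing and fastBPE codes);
- the masked position is written as `<special1>`, which the repo uses as its mask token;
- each `<special1>` stands for a single BPE piece.

`build_batch` and `predict_masked` are hypothetical helper names, not part of the repo.

```python
import torch

# Assumption: XLM's mask token; src/data/dictionary.py defines MASK_WORD as '<special1>'.
MASK_WORD = "<special1>"


def build_batch(bpe_sentence, dico):
    """Encode one preprocessed, BPE-split sentence in XLM's (slen, bs) layout.

    Hypothetical helper. `dico` is the checkpoint's Dictionary, so word IDs
    come from the vocabulary the model was trained with (unknown words -> <unk>).
    """
    words = ["</s>"] + bpe_sentence.split() + ["</s>"]  # sentence delimiters
    x = torch.LongTensor([dico.index(w) for w in words]).unsqueeze(1)  # (slen, 1)
    lengths = torch.LongTensor([len(words)])
    # True wherever a word should be predicted, shape (slen, 1)
    pred_mask = torch.tensor([w == MASK_WORD for w in words]).unsqueeze(1)
    return x, lengths, pred_mask


@torch.no_grad()
def predict_masked(model, dico, bpe_sentence, lang_id=0, top_k=5, device="cpu"):
    """Return the top_k candidate words for each MASK_WORD in the sentence.

    Hypothetical helper; `model` is an XLM TransformerModel built with
    with_output=True and already in eval mode.
    """
    x, lengths, pred_mask = build_batch(bpe_sentence, dico)
    x, lengths, pred_mask = x.to(device), lengths.to(device), pred_mask.to(device)
    langs = torch.full_like(x, lang_id)  # same language id at every position

    # Final hidden states, shape (slen, bs, dim)
    tensor = model("fwd", x=x, lengths=lengths, langs=langs, causal=False)

    # predict() also computes a loss, so it needs one target per masked
    # position; any non-pad index works, the mask index is used here.
    n_masked = int(pred_mask.sum().item())
    y = torch.full((n_masked,), dico.index(MASK_WORD), dtype=torch.long, device=device)
    scores, _ = model("predict", tensor=tensor, pred_mask=pred_mask, y=y, get_scores=True)

    top = scores.topk(top_k, dim=-1).indices  # (n_masked, top_k)
    return [[dico[i] for i in row.tolist()] for row in top]
```

With the objects from the question, a call might look like `predict_masked(encoder, dico, bpe_sentence, lang_id=model_params.lang2id['en'], device='cuda')`. Here `bpe_sentence` is the question's sentence after the training preprocessing and BPE, with `<special1>` in place of the masked word. The dummy `y` only feeds the loss that `predict` computes alongside the scores. The returned candidates are BPE pieces, so they may still carry `@@` continuation markers.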