facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.

How to get masked word prediction for other languages in XLM-Roberta #3437

Closed AnnaSou closed 2 years ago

AnnaSou commented 3 years ago

Hello,

I am trying to get masked word predictions for languages other than English with XLM-RoBERTa.

import torch
from fairseq.models.roberta import XLMRModel
xlmr = XLMRModel.from_pretrained('xlmr.large', checkpoint_file='model.pt')
xlmr.eval()  # disable dropout (or leave in train mode to finetune)

The English example worked reasonably well.

print(xlmr.fill_mask('Cats like <mask>.', topk=10))

[('Cats like it.', 0.08111538738012314, ' it'), ('Cats like me.', 0.07621509581804276, ' me'), ('Cats like to.', 0.05113542079925537, ' to'), ('Cats like this.', 0.04705266281962395, ' this'), ('Cats like animals.', 0.04491560906171799, ' animals'), ('Cats like sex.', 0.029887961223721504, ' sex'), ('Cats like you.', 0.025108635425567627, ' you'), ('Cats like them.', 0.02068178914487362, ' them'), ('Cats like food.', 0.01644749753177166, ' food'), ('Cats like love.', 0.015357503667473793, ' love')]

The example in Chinese worked very poorly:
print(xlmr.fill_mask('猫喜欢 <mask>.', topk=10))
[('猫喜欢... .', 0.10599668323993683, '...'), ('猫喜欢 ... .', 0.09807392209768295, ' ...'), ('猫喜欢. .', 0.03559407964348793, '.'), ('猫喜欢 : .', 0.030052676796913147, ' :'), ('猫喜欢...... .', 0.02641991153359413, '......'), ('猫喜欢! .', 0.025788016617298126, '!'), ('猫喜欢..... .', 0.02024935372173786, '.....'), ('猫喜欢: .', 0.016514059156179428, ':'), ('猫喜欢 ! .', 0.016426226124167442, ' !'), ('猫喜欢我 .', 0.016250096261501312, '我')]

The example in Russian worked a bit better than the English one, and much better than the Chinese one.
print(xlmr.fill_mask('Коты любят <mask>.', topk=10))
[('Коты любят музыку.', 0.06512077152729034, ' музыку'), ('Коты любят животных.', 0.05911742523312569, ' животных'), ('Коты любят детей.', 0.04838460311293602, ' детей'), ('Коты любят людей.', 0.048284295946359634, ' людей'), ('Коты любят собак.', 0.03620772063732147, ' собак'), ('Коты любят воду.', 0.035584960132837296, ' воду'), ('Коты любят секс.', 0.02180316112935543, ' секс'), ('Коты любят природу.', 0.021794842556118965, ' природу'), ('Коты любят огонь.', 0.018937349319458008, ' огонь'), ('Коты любят..', 0.01525605283677578, '.')]

Maybe I am doing something wrong. How should I use the multilingual XLM-RoBERTa for a masked prediction task? Ideally, I want to query the model with target words, for instance xlmr.fill_mask('Cats like <mask>.', targets=['sleeping']), so that I would get the probability for the word "sleeping". I would also like to do this for other languages.

Thanks! Anna

zhang-xi commented 3 years ago

I don't think you did anything wrong, because I tried other Chinese sentences and they worked fine.

I tried '能帮我<mask>一次下周二大约6点的棒球赛吗 ?' (roughly, "Can you help me <mask> a baseball game around 6 o'clock next Tuesday?"), and got the following results: [('能帮我推荐一次下周二大约6点的棒球赛吗 ?', 0.26250752806663513, '推荐'), ('能帮我找一次下周二大约6点的棒球赛吗 ?', 0.17815212905406952, '找'), ('能帮我安排一次下周二大约6点的棒球赛吗 ?', 0.06453423947095871, '安排'), ('能帮我找到一次下周二大约6点的棒球赛吗 ?', 0.033767491579055786, '找到'), ('能帮我看一次下周二大约6点的棒球赛吗 ?', 0.03100642003118992, '看'), ('能帮我选一次下周二大约6点的棒球赛吗 ?', 0.029681149870157242, '选'), ('能帮我预约一次下周二大约6点的棒球赛吗 ?', 0.023714391514658928, '预约'), ('能帮我联系一次下周二大约6点的棒球赛吗 ?', 0.01841096580028534, '联系'), ('能帮我查一次下周二大约6点的棒球赛吗 ?', 0.014259709045290947, '查'), ('能帮我报一次下周二大约6点的棒球赛吗 ?', 0.013958051800727844, '报')]

Maybe the sentence '猫喜欢 <mask>.' is just not well covered by the training data.

In order to get the probability of a specified target word, I think you may need to modify the fill_mask function in fairseq/models/roberta/hub_interface.py. At line 187, prob = logits.softmax(dim=0) gives the probabilities over all words in the dictionary. You would need to encode the target word with the dictionary first and then look up its probability at the encoded index.
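Something like the following minimal sketch may work (this is not an existing fairseq API; the helper name target_word_prob and the assumption that the target maps to a single BPE piece are mine). It reuses the same BPE and dictionary that fill_mask uses internally:

import torch

def target_word_prob(xlmr, masked_input, target, mask_token='<mask>'):
    # Mirror what fill_mask does: BPE-encode the text around the mask token.
    spans = masked_input.split(mask_token)
    spans_bpe = (' {} '.format(mask_token)).join(
        xlmr.bpe.encode(span.rstrip()) for span in spans
    ).strip()
    tokens = xlmr.task.source_dictionary.encode_line(
        '<s> ' + spans_bpe + ' </s>', append_eos=False, add_if_not_exist=False,
    ).long()
    masked_index = (tokens == xlmr.task.mask_idx).nonzero().item()

    # Forward pass with features_only=False to get vocabulary logits.
    with torch.no_grad():
        logits, _ = xlmr.model(tokens.unsqueeze(0), features_only=False)
    prob = logits[0, masked_index].softmax(dim=0)

    # Encode the target word and require it to be a single BPE piece
    # (multi-piece targets would have to be scored piece by piece).
    target_pieces = xlmr.bpe.encode(target).split()
    if len(target_pieces) != 1:
        raise ValueError('target spans %d BPE pieces' % len(target_pieces))
    target_idx = xlmr.task.source_dictionary.index(target_pieces[0])
    return prob[target_idx].item()

# e.g. target_word_prob(xlmr, 'Cats like <mask>.', 'sleeping')

This just replicates the tokenization and forward pass from fill_mask and then indexes the softmax output at the target's dictionary index instead of taking the top k.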

Hope this can help you.

stale[bot] commented 2 years ago

This issue has been automatically marked as stale. If this issue is still affecting you, please leave any comment (for example, "bump"), and we'll keep it open. We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment!

stale[bot] commented 2 years ago

Closing this issue after a prolonged period of inactivity. If this issue is still present in the latest release, please create a new issue with up-to-date information. Thank you!