Thanks for flagging. We are still trying to get to the bottom of which special tokens to use for mbart-large-cc25. See https://github.com/pytorch/fairseq/issues/2258.
This might not be the best solution, but after experimenting with the tokenizer's special tokens a bit, the model seems insensitive to the first input_id and to the lang_code used on the encoder side. So after these modifications:
```python
def set_src_lang_special_tokens(self, src_lang) -> None:
    """Reset the special tokens to the source lang setting. Prefix=[bos] and suffix=[eos]."""
    self.cur_lang_code = self.lang_code_to_id[src_lang]
    self.prefix_tokens = [self.bos_token_id]
    self.suffix_tokens = [self.eos_token_id]

def set_tgt_lang_special_tokens(self, lang: str) -> None:
    """Reset the special tokens to the target language setting. Prefix=[tgt_lang_code] and suffix=[eos]."""
    self.cur_lang_code = self.lang_code_to_id[lang]
    self.prefix_tokens = [self.cur_lang_code]
    self.suffix_tokens = [self.eos_token_id]
```
in tokenization_bart.py, the model seems to be doing the right thing and generating correct English output:
```
src_sent: UN Chief Says There Is No Military Solution in Syria
src_ids: {'input_ids': tensor([[     0,   8274, 127873,  25916,      7,   8622,   2071,    438,  67485,
             53, 187895,     23,  51712,      2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
output_ids: tensor([[250004,      0,   8274, 127873,  25916,      7,   8622,   2071,    438,
             67485,     53, 187895,     23,  51712]])
output: UN Chief Says There Is No Military Solution in Syria
```
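For anyone who would rather not patch tokenization_bart.py, roughly the same effect can be reproduced by building the special-token layout by hand. This is only a sketch: the `[<s>] + tokens + [</s>]` layout and the en_XX decoder start token are assumptions taken from the modification and output above, not the library's intended API.

```python
import torch
from transformers import MBartTokenizer, MBartForConditionalGeneration

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")

src_sent = "UN Chief Says There Is No Military Solution in Syria"
ids = tokenizer(src_sent, add_special_tokens=False)["input_ids"]

# Assumed layout, mirroring the modified set_src_lang_special_tokens above:
# [<s>] + tokens + [</s>], with no source language code.
input_ids = torch.tensor([[tokenizer.bos_token_id] + ids + [tokenizer.eos_token_id]])

# Start decoding from the en_XX language code, as in the reproduction script later
# in this thread.
output_ids = model.generate(
    input_ids,
    decoder_start_token_id=tokenizer.lang_code_to_id["en_XX"],
)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
```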
However, this is not how things are described in the mBART paper, so the core issue remains.
Also, this code segment in modeling_bart.py:

```python
def adjust_logits_during_generation(self, logits, cur_len, max_length, **kwargs):
    if cur_len == 1:
        self._force_token_ids_generation(logits, self.config.bos_token_id)
```

is the culprit for generating 0 (corresponding to bos_token) at the first decoding step of every sequence. It might need to be changed for the mBART model.
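One possible workaround, shown here only as a monkey-patching sketch rather than an official fix, is to neutralise that hook on the loaded model so that nothing is forced at the first step (note this also drops any other adjustments the original hook performed):

```python
from transformers import MBartForConditionalGeneration

model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")

def adjust_logits_no_forced_bos(logits, cur_len, max_length, **kwargs):
    # Return the logits unchanged instead of forcing bos_token_id (0) when cur_len == 1.
    return logits

# Override the instance hook used by generate(); the signature mirrors the excerpt above.
model.adjust_logits_during_generation = adjust_logits_no_forced_bos
```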
I find that the input does need to start with <bos>, and the decoder should be seeded with <lang_code> <bos>. With this setup, I am able to recover the input sequence during decoding. Like @Mehrad0711, I find that the input lang_code does not make a significant difference.
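A rough sketch of that setup follows; it is illustrative only, with the checkpoint and token layout taken from the discussion above rather than any official recipe:

```python
import torch
from transformers import MBartTokenizer, MBartForConditionalGeneration

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
model.eval()

src = "UN Chief Says There Is No Military Solution in Syria"
ids = tokenizer(src, add_special_tokens=False)["input_ids"]

# Encoder input starting with <s>; decoder seeded with [lang_code, <s>].
input_ids = torch.tensor([[tokenizer.bos_token_id] + ids + [tokenizer.eos_token_id]])
decoder_input_ids = torch.tensor([[tokenizer.lang_code_to_id["en_XX"], tokenizer.bos_token_id]])

with torch.no_grad():
    logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)[0]

# Greedy continuation from this seed should start reproducing the input tokens.
next_id = logits[0, -1].argmax().item()
print(tokenizer.convert_ids_to_tokens([next_id]))
```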
@tomhosking I replicated what you wrote. It definitely needs <s> (bos) at the beginning of the input string to fix the off-by-one error. You can see the tests in PR #6524.
I'm still having trouble squaring this and finding a unified fix that accommodates the behavior that mbart-large-en-ro seems to want, as shown in #6156.
Maybe the simplest change is just to add <s> to the start of the encoder-side string?
Interestingly, prepend_bos is set to false by default in the fairseq mbart finetuning docs.
I set a breakpoint during finetuning and there is no BOS to be found; here is how the batches look:
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
Hi,
Reviving this issue as it still persists in recent versions of transformers. The solution I proposed in the Aug 2 comment seems to work after a modification of the source prefix token, borrowed from tokenization_mbart50.py.
Proposed solution:
```python
def set_src_lang_special_tokens(self, src_lang) -> None:
    """Reset the special tokens to the source lang setting. Prefix=[src_lang_code] and suffix=[eos]."""
    self.cur_lang_code = self.lang_code_to_id[src_lang]
    self.prefix_tokens = [self.cur_lang_code]
    self.suffix_tokens = [self.eos_token_id]

def set_tgt_lang_special_tokens(self, lang: str) -> None:
    """Reset the special tokens to the target language setting. Prefix=[tgt_lang_code] and suffix=[eos]."""
    self.cur_lang_code = self.lang_code_to_id[lang]
    self.prefix_tokens = [self.cur_lang_code]
    self.suffix_tokens = [self.eos_token_id]
```
This change will fix mbart-large-cc25's output text while leaving mbart-large-en-ro's untouched.
Code to reproduce:
```python
from transformers import MBartTokenizer, MBartForConditionalGeneration

tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-cc25')
model = MBartForConditionalGeneration.from_pretrained('facebook/mbart-large-cc25')

src_sent = "UN Chief Says There Is No Military Solution in Syria"
batch = tokenizer.prepare_seq2seq_batch(src_texts=[src_sent], src_lang="en_XX", return_tensors="pt")
output_ids = model.generate(**batch, decoder_start_token_id=tokenizer.lang_code_to_id["en_XX"])
output = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]

print('src_sent: ', src_sent)
print('src_ids: ', batch["input_ids"])
print('output_ids: ', output_ids)
print('output: ', output)
```
stdout (before change):

```
src_sent:  UN Chief Says There Is No Military Solution in Syria
src_ids:  tensor([[  8274, 127873,  25916,      7,   8622,   2071,    438,  67485,     53,
             187895,     23,  51712,      2, 250004]])
output_ids:  tensor([[250004,      0, 127873,  25916,      7,   8622,   2071,    438,  67485,
                 53, 187895,     23,  51712,      2]])
output:  Chief Says There Is No Military Solution in Syria
```
stdout (after change):

```
src_sent:  UN Chief Says There Is No Military Solution in Syria
src_ids:  tensor([[250004,   8274, 127873,  25916,      7,   8622,   2071,    438,  67485,
                 53, 187895,     23,  51712,      2]])
output_ids:  tensor([[250004,      0,   8274, 127873,  25916,      7,   8622,   2071,    438,
              67485,     53, 187895,     23,  51712,      2]])
output:  UN Chief Says There Is No Military Solution in Syria
```
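As a quick follow-up check, the round trip can be asserted directly; this tiny sketch reuses the variables from the reproduction script above:

```python
# With the patched prefix/suffix tokens, denoising the unmasked sentence
# should reproduce it verbatim.
assert output == src_sent, (src_sent, output)
```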
Potential reviewers: @patrickvonplaten, @patil-suraj, @sgugger
Hi @sgugger, @patrickvonplaten, @patil-suraj, it would be great if you could provide your feedback on this. I would be happy to provide more context if needed.
Hi @Mehrad0711
Thank you for reporting this. I'll go through the original model code in fairseq to see how they are handling the prefix tokens and get back here.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hi @patil-suraj, Was wondering if you had the chance to take a look at this. Thanks.
Hi @Mehrad0711
You are right, all mBART models actually use the language code as the prefix token and <eos> as the suffix token.
But unfortunately, we can't really change it now, because this would be backward-incompatible with the other models trained using the existing format.
Also, this doesn't make much of a difference if you want to fine-tune the model: as long as a consistent format is used for fine-tuning and then for inference, it should work. It would, however, change the output of the pre-trained models (as you reported). But since mbart-large-cc25 is just a pre-trained model and needs to be fine-tuned for downstream tasks, this doesn't seem like a big issue.
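To illustrate the consistency point, here is a sketch, not an official recipe: the Romanian target sentence is only an illustrative example, and the key idea is simply that the same tokenizer format feeds both fine-tuning and inference.

```python
from transformers import MBartTokenizer, MBartForConditionalGeneration

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")

# Fine-tuning side: whatever special-token layout prepare_seq2seq_batch produces
# is the format the model will learn.
batch = tokenizer.prepare_seq2seq_batch(
    src_texts=["UN Chief Says There Is No Military Solution in Syria"],
    tgt_texts=["Şeful ONU declară că nu există o soluţie militară în Siria"],
    src_lang="en_XX",
    tgt_lang="ro_RO",
    return_tensors="pt",
)
outputs = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    labels=batch["labels"],
)
loss = outputs[0]  # optimize this during fine-tuning

# Inference side: tokenize the source the same way and start decoding from the
# target language code, mirroring what the model saw during fine-tuning.
generated = model.generate(
    batch["input_ids"],
    attention_mask=batch["attention_mask"],
    decoder_start_token_id=tokenizer.lang_code_to_id["ro_RO"],
)
```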
Hi @patil-suraj!
Thanks for your reply. I understand the concern regarding backward incompatibility of this change.
I was using mbart-large-cc25 without fine-tuning for text denoising; that's how the problem popped up. Given that newer mBART models are now available on the Hugging Face Hub, I'll switch to them.
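For reference, a quick look at one of the newer checkpoints mentioned above (a sketch; the checkpoint name and layout are as documented for mBART-50): its tokenizer already emits [src_lang_code] ... [eos], i.e. the layout proposed in this thread.

```python
from transformers import MBart50TokenizerFast

tok = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50", src_lang="en_XX")
ids = tok("UN Chief Says There Is No Military Solution in Syria")["input_ids"]
tokens = tok.convert_ids_to_tokens(ids)
print(tokens[0], tokens[-1])  # expected: the en_XX language code first, </s> last
```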
🐛 Bug
Information
Model I am using (Bert, XLNet ...): MBART
Language I am using the model on (English, Chinese ...): English, Romanian
The problem arises when using:
The tasks I am working on is:
To reproduce
Steps to reproduce the behavior:
I'm examining 'facebook/mbart-large-en-ro' and 'facebook/mbart-large-cc25' checkpoints of MBART.
Here is my first script translating an English sentence to Romanian:
stdout:
As seen in output_ids, the model always generates 0 (corresponding to bos_token) at the first decoding step. However, this does not seem to be a problem with this checkpoint, as the output is still the correct translation.
Now I run the same script but with the pretrained "facebook/mbart-large-cc25", trying to denoise an English input. Since the input does not have mask tokens, the output should be identical to the input, given the pretraining objective of mBART. However, the output always misses the first token of the input. I have observed this with different examples (even when there are masked tokens in the input).
the stdout:
I have tried various approaches but haven't found any clear solutions to this. Appreciate any help on this.
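For completeness, here is a hedged sketch of the masked variant mentioned above, using the tokenizer's <mask> token; the exact reconstruction is not guaranteed and the example sentence is only illustrative:

```python
from transformers import MBartTokenizer, MBartForConditionalGeneration

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")

# Replace a span with the mask token and ask the pretrained denoiser to
# reconstruct the sentence.
masked = "UN Chief Says There Is No <mask> in Syria"
batch = tokenizer.prepare_seq2seq_batch(src_texts=[masked], src_lang="en_XX", return_tensors="pt")
out = model.generate(**batch, decoder_start_token_id=tokenizer.lang_code_to_id["en_XX"])
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```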
Environment info
transformers version: 3.0.2