yhcc / BARTNER

Indexes shifting #1

Closed: xinsu626 closed this issue 3 years ago

xinsu626 commented 3 years ago

Hello, nice work! Sorry if I missed something. I have a question about the decoder's output in your code.

Based on your code, it seems you shift the position indexes of the tokens by the number of labels. I was wondering why you shift the tokens instead of the labels. Thank you!

Please find an example below (results from here).

raw_words: ['SOCCER', '-', 'JAPAN', 'GET', 'LUCKY', 'WIN', ',', 'CHINA', 'IN', 'SURPRISE', 'DEFEAT', '.']
target_ids: [0, 11, 2, 20, 3, 1]    # each of the token positions is shifted by 6
word_bpe_ids: [0, 13910, 3376, 2076, 111, 344, 591, 1889, 7777, 226, 23806, 975, 17164, 2156, 3858, 16712, 2808, 31987, 4454, 18819, 5885, 10885, 2571, 479, 2]
word_bpe_tokens: ['<s>', 'ĠSO', 'CC', 'ER', 'Ġ-', 'ĠJ', 'AP', 'AN', 'ĠGET', 'ĠL', 'UCK', 'Y', 'ĠWIN', 'Ġ,', 'ĠCH', 'INA', 'ĠIN', 'ĠSUR', 'PR', 'ISE', 'ĠDE', 'FE', 'AT', 'Ġ.', '</s>']
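
Here is a minimal sketch of the encoding being asked about. It assumes a label space of two special ids followed by four entity tags. The tag names and their order are hypothetical; the example only implies that the label space has 6 entries. Token positions are offset by the size of that label space, so the example target above round-trips:

```python
# Assumed label space: BOS, EOS, then four entity tags (names/order are hypothetical).
LABELS = ["<bos>", "<eos>", "<<location>>", "<<person>>", "<<organization>>", "<<others>>"]
BOS_ID, EOS_ID, FIRST_TAG_ID = 0, 1, 2
SHIFT = len(LABELS)  # token positions start right after the label ids


def encode_target(entities):
    """entities: list of (bpe_positions, tag) pairs -> flat decoder target ids."""
    ids = [BOS_ID]
    for positions, tag in entities:
        ids.extend(p + SHIFT for p in positions)  # pointer into the input BPE sequence
        ids.append(LABELS.index(tag))             # entity tag id, never shifted
    ids.append(EOS_ID)
    return ids


def decode_target(ids):
    """Inverse of encode_target: split ids back into (bpe_positions, tag) pairs."""
    entities, positions = [], []
    for i in ids[1:]:  # skip BOS
        if i == EOS_ID:
            break
        if i >= SHIFT:
            positions.append(i - SHIFT)
        elif i >= FIRST_TAG_ID:
            entities.append((positions, LABELS[i]))
            positions = []
    return entities


# JAPAN starts at BPE position 5 ('ĠJ'), CHINA at position 14 ('ĠCH').
target = encode_target([([5], "<<location>>"), ([14], "<<person>>")])
print(target)                 # [0, 11, 2, 20, 3, 1]
print(decode_target(target))  # [([5], '<<location>>'), ([14], '<<person>>')]
```

Subtracting `SHIFT` from a pointer id recovers the BPE index, so `word_bpe_tokens[11 - 6]` is `'ĠJ'` and `word_bpe_tokens[20 - 6]` is `'ĠCH'`.
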
yhcc commented 3 years ago

Because shifting the labels would make the shift depend on the number of tokens in the input. The indexes of the special tags would then vary from input to input, which would make it hard for the beam search algorithm to determine the finish (EOS) code.
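
To make the argument concrete, the sketch below builds the decoder's output index space both ways and looks up where EOS lands. It assumes the same 6-entry label list as the example. The `pos:i` entries are hypothetical placeholders for BPE positions.

```python
# Assumed label space (same 6 entries as the example's shift of 6).
LABELS = ["<bos>", "<eos>", "<<location>>", "<<person>>", "<<organization>>", "<<others>>"]


def index_space(num_bpe_tokens, scheme):
    """Return the decoder's output index space as a list of symbolic entries."""
    positions = [f"pos:{i}" for i in range(num_bpe_tokens)]  # hypothetical placeholders
    if scheme == "shift_tokens":  # labels first, token positions shifted past them
        return LABELS + positions
    if scheme == "shift_labels":  # token positions first, labels shifted past them
        return positions + LABELS
    raise ValueError(f"unknown scheme: {scheme}")


def eos_index(num_bpe_tokens, scheme):
    return index_space(num_bpe_tokens, scheme).index("<eos>")


for n in (10, 25):
    print(n, eos_index(n, "shift_tokens"), eos_index(n, "shift_labels"))
# 10 1 11
# 25 1 26
```

With token shifting, EOS stays at 1 for every input, so a beam search can be given one fixed end id. With label shifting, that id would have to be recomputed for every example in a batch.
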

xinsu626 commented 3 years ago

Hi @yhcc, got it. Thanks for your reply! Sorry, I have a follow-up question. Is it because you put the EOS token in the second position of the label space ([BOS, EOS, Tag1, ...]) that you set the EOS token id to 1, instead of BART's original 2, during generation (the inference phase)?
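
The stop condition under discussion can be shown with a small sketch. For brevity, greedy decoding stands in for beam search. `step` is a hypothetical scoring callback, not BARTNER's interface. In the example's pointer/label space the target ends with id 1. Keeping BART's original end id 2 would instead cut the sequence off at the first entity tag.

```python
from typing import Callable, List

# step(prefix) returns one score per id in the pointer/label space (hypothetical interface).
Step = Callable[[List[int]], List[float]]


def greedy_decode(step: Step, bos_id: int = 0, eos_id: int = 1, max_len: int = 20) -> List[int]:
    """Greedy decoding that stops as soon as eos_id is produced."""
    seq = [bos_id]
    while len(seq) < max_len:
        scores = step(seq)
        next_id = max(range(len(scores)), key=scores.__getitem__)
        seq.append(next_id)
        if next_id == eos_id:
            break
    return seq


def scripted_step(script: List[int], vocab_size: int = 31) -> Step:
    """Toy stand-in for a model: always prefers the next id from `script`."""
    def step(seq: List[int]) -> List[float]:
        wanted = script[len(seq) - 1]
        return [1.0 if i == wanted else 0.0 for i in range(vocab_size)]
    return step


print(greedy_decode(scripted_step([11, 2, 20, 3, 1])))            # [0, 11, 2, 20, 3, 1]
print(greedy_decode(scripted_step([11, 2, 20, 3, 1]), eos_id=2))  # [0, 11, 2]
```
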

yhcc commented 3 years ago

Yes. We map EOS id 1 to 2 in the forward function of our model (so that BART still receives its proper EOS token id).
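
Below is a sketch of that mapping; it is not the repository's code. It assumes BART's standard special ids (`<s>` = 0, `</s>` = 2, consistent with `word_bpe_ids` above) and uses hypothetical vocabulary ids for the four tag tokens. Before pointer-space ids are fed back to the decoder, they are converted to real BART token ids:

```python
from typing import List

BART_BOS_ID = 0  # <s> in BART's vocabulary
BART_EOS_ID = 2  # </s> in BART's vocabulary
NUM_LABELS = 6   # assumed: BOS, EOS, four entity tags


def to_bart_ids(target_ids: List[int], src_bpe_ids: List[int], tag_token_ids: List[int]) -> List[int]:
    """Map pointer/label-space ids to BART token ids for the decoder input."""
    out = []
    for t in target_ids:
        if t == 0:
            out.append(BART_BOS_ID)
        elif t == 1:
            out.append(BART_EOS_ID)  # pointer-space EOS 1 -> BART's </s> 2
        elif t < NUM_LABELS:
            out.append(tag_token_ids[t - 2])  # entity tag -> its added vocab token
        else:
            out.append(src_bpe_ids[t - NUM_LABELS])  # pointer -> copied source BPE id
    return out


word_bpe_ids = [0, 13910, 3376, 2076, 111, 344, 591, 1889, 7777, 226, 23806, 975, 17164,
                2156, 3858, 16712, 2808, 31987, 4454, 18819, 5885, 10885, 2571, 479, 2]
tag_ids = [50265, 50266, 50267, 50268]  # hypothetical ids of added tag tokens
print(to_bart_ids([0, 11, 2, 20, 3, 1], word_bpe_ids, tag_ids))
# [0, 344, 50265, 3858, 50266, 2]
```

The decoder thus reasons in the compact pointer/label space, where EOS is fixed at 1, while BART's embeddings still see its own `</s>` id.
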

xinsu626 commented 3 years ago

@yhcc Thank you! This is really helpful.