huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

bug in model.generate() / beam_search score per token #16053

Closed rbbb closed 2 years ago

rbbb commented 2 years ago

Tried on transformers 4.16.2; not platform dependent (the bug is in the algorithm).

Copy/paste of an email sent to cwkeam:

I tried to generate some phrases with transformers (BART), then tried to get the perplexity of each generated token, and noticed the per-token scores were wrong.

If I had to give a half-baked patch, it would look something like this:

generation_utils.py, line 2259
scores = None #torch.zeros((batch_size*num_beams, 0, ), dtype=torch.float, device=input_ids.device) if (return_dict_in_generate and output_scores) else None

generation_utils.py, line 2319
                #if output_scores:
                #    scores += (next_token_scores,)

generation_utils.py, line 2337, insert:
#this is necessary because the next_token_scores get sorted (apparently in place)
#and also the variable gets overwritten (by a softmax, etc...)
old_input_scores = next_token_scores.view(batch_size*num_beams, vocab_size).clone()

generation_utils.py, line 2374:
#I'm not sure the clone() is necessary
            if return_dict_in_generate and output_scores:
                scores = old_input_scores[beam_idx, :].unsqueeze(-2) if scores is None else torch.cat(
                    (scores[beam_idx, :, :].clone(), old_input_scores[beam_idx, :].unsqueeze(-2)), dim=-2)

This will give you proper scores for the beams that generated the output, not the beams that were present when the nth token was generated.

My pseudo-patch's downsides:

I don't really know the coding standards for transformers, and I'm not very good with PyTorch, but would you be interested in having such a patch? The main point is token/phrase perplexity.

Copy/paste of a notebook reproducing the bug:

import os,sys
import re, itertools, collections
import pandas, numpy
import transformers, torch
import torch.nn as nn

#not model dependent
from transformers import BartTokenizer, BartForConditionalGeneration, AdamW, BartConfig
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

prompt = ['I went to a restaurant. '+review+'. The food was very <mask> in my opinion.' for review in ('The ambiance was bad',)]

#show the probability of the 5 most probable tokens at each token generated
tokens = tokenizer(prompt, return_tensors="pt", padding=True)
token_ids, attn_mask = tokens['input_ids'], tokens['attention_mask']
for i in range(2):
    out = bart_model.generate(token_ids, max_length=100, do_sample=True, top_k=500, temperature=1.5, num_beams=4, num_return_sequences=1, output_scores=True, return_dict_in_generate=True, output_hidden_states=True)
    outscores = torch.cat([out['scores'][i].unsqueeze(1) for i in range(len(out['scores']))], dim=1)
    print("best beam is seq: ", out['beam_indices'][0][-1].item())
    num_beams, num_chars, char_dim = outscores.shape
    num_top = 5
    print(out['beam_indices'])
    for s in range(num_beams):
        for c in range(num_chars):
            probas = nn.functional.softmax(outscores[s,c,:], dim=-1)
            top_proba, top_idx = probas.topk(num_top, dim=-1)
            num_seqs = top_proba.shape[0]
            print("seq"+str(s)+"char"+str(c),"---".join([ str("{:0.1f}".format(p.item()*100))+" "+tokenizer.decode(top_idx[i]) for i,p in enumerate(top_proba)]))

    print("out",tokenizer.batch_decode(out['sequences'], skip_special_tokens=True))

Example output and analysis

out ['I went to a restaurant. The ambiance was bad. The food was very bland and unauthentic in my opinion.']
seq1char0 99.7 <s>---0.0 The---0.0 I---0.0 There---0.0 Each
seq1char1 99.9 I---0.0 </s>---0.0 We---0.0 It---0.0 You
seq1char2 100.0  went---0.0  ate---0.0  go---0.0  Went---0.0  came
seq1char3 100.0  to---0.0  a---0.0  on---0.0  in---0.0  at
seq1char4 100.0  a---0.0  A---0.0  this---0.0  an---0.0  the
seq1char5 100.0  restaurant---0.0  Chinese---0.0  Mexican---0.0  Thai---0.0  hotel
seq1char6 100.0 .---0.0 ,---0.0  where---0.0  in---0.0  that
seq1char7 99.9  The---0.0  It---0.0 The---0.0  the---0.0  I
seq1char8 100.0  amb---0.0  Amb---0.0  lighting---0.0  vibe---0.0  restaurant
seq1char9 100.0 iance---0.0 iances---0.0 ience---0.0 ation---0.0  amb
seq1char10 100.0  was---0.0  –---0.0  is---0.0 ,---0.0  were
seq1char11 100.0  bad---0.0  terrible---0.0  horrible---0.0  good---0.0  negative
seq1char12 100.0 .---0.0 ,---0.0 ;---0.0  and---0.0  -
seq1char13 99.8  The---0.0  It---0.0  the---0.0  I---0.0  There
seq1char14 99.9  food---0.0  restaurant---0.0  Food---0.0  music---0.0  wine
seq1char15 99.9  was---0.0  is---0.0 .---0.0 ,---0.0  in
seq1char16 99.8  very---0.0  not---0.0  terrible---0.0  bad---0.0  too
seq1char17 12.9  bad---10.6  good---3.2  poor---2.5  bland---2.3  mediocre
seq1char18 23.5 ,---15.2 .---12.5  in---9.3  and---7.6  but
seq1char19 69.1  my---3.0  the---1.1  its---0.9  a---0.8  general
seq1char20 15.4 app---7.8 interesting---6.2 ----5.1 inspired---3.4 original
seq1char21 93.2  my---0.5  the---0.3  their---0.2  its---0.2  a
seq1char22 99.7  opinion---0.0  view---0.0  favor---0.0  humble---0.0  taste
seq1char23 99.8  opinion---0.0  favor---0.0  view---0.0  taste---0.0  humble
seq1char24 99.9 .---0.0 ,---0.0  and---0.0 ;---0.0 ..
seq1char25 99.9 .---0.0 ,---0.0  and---0.0 ;---0.0 ..
seq1char26 100.0 .---0.0 ,---0.0 </s>---0.0 ;---0.0  and

You can see it reports generating 'opinion' twice, which is not the case in the output phrase.

This is because scores += (logits_warper(input_ids, next_token_scores_processed),) (generation_utils.py, line 2319) is not the probability of the beams after a fork. If, say, beam 0 has such a high score that all beams get replaced by a copy of beam 0 with a new token each, then the scores update should be something like scores = scores[beam_0==0].repeat(1,4) + scores_for_new_tokens_added_to_beam_0.

This is more or less what my pseudo-patch does correctly.
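
For reference, a rough post-hoc sketch of the same idea (not the patch itself), assuming out['beam_indices'][0][t] really is the index, into the batch_size*num_beams dimension, of the beam that produced step t of the first returned sequence:

#rough post-hoc sketch: realign the stored per-step scores so that row t corresponds
#to the beam that actually produced token t of the first returned sequence
#(assumes beam_indices has one valid entry per generation step; -1 padding would need to be skipped)
aligned_scores = torch.stack(
    [out['scores'][t][out['beam_indices'][0][t]] for t in range(len(out['scores']))]
)  #shape: (num_generated_tokens, vocab_size)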

Tell me if you want a patch (I'm not ultra-proficient with PyTorch but can make a first version) or want to take a look yourselves.

@cwkeam

cwkeam commented 2 years ago

(this really should be tagged for the HF team @patrickvonplaten, and I'm not sure I have the authority to answer, but here's my reply)

So I think that the issue you're noticing is coming from the fact that scores is saved before the actual beam search process at step k.

If say beam 0 has such high score that all beams get replaced by a copy of beam 0 with a new token each, then the scores update should be
scores = scores[beam_0==0].repeat(1,4) + scores_for_new_tokens_added_to_beam_0

Though it wouldn't be implemented this way, what you're expressing is that you want scores to reflect the beam selection done by the BeamScorer. I believe so because "if say beam 0 has such high score that all beams get replaced by a copy of beam 0 with a new token each" is something that happens inside BeamScorer.process.

Currently the code goes:

outputs = ...  # model outputs
# ...
next_token_scores_processed = logits_processor(input_ids, next_token_scores)

next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)

# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
    if output_scores:
        scores += (next_token_scores,)
    if output_attentions:
        decoder_attentions += (
            (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
        )
        if self.config.is_encoder_decoder:
            cross_attentions += (outputs.cross_attentions,)

    if output_hidden_states:
        decoder_hidden_states += (
            (outputs.decoder_hidden_states,)
            if self.config.is_encoder_decoder
            else (outputs.hidden_states,)
        )

# ... THEN do beam scorer process logic that actually picks the tokens to go with
beam_outputs = beam_scorer.process(
    input_ids,
    next_token_scores,
    # ...
)
beam_scores = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
beam_idx = beam_outputs["next_beam_indices"]

input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

So currently it's:

STEP_1 input_ids -> step_1_scores -> SAVE
beamsearch(step_1_scores) -> beam_scores_1 -> STEP_2 input_ids

STEP_2 input_ids -> step_2_scores -> SAVE
beamsearch(step_2_scores) -> beam_scores_2 -> STEP_3 input_ids

and I think what you want is

STEP_1 input_ids -> step_1_scores
beamsearch(step_1_scores) -> beam_scores_1 -> SAVE -> STEP_2 input_ids

STEP_2 input_ids -> step_2_scores 
beamsearch(step_2_scores) -> beam_scores_2 -> SAVE -> STEP_3 input_ids

or maybe even save before AND after beam search.
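
For illustration, a minimal sketch of what the "save after beam search" variant could look like, reusing the variable names from the snippet above (hypothetical, not the actual implementation):

# hypothetical: store the per-token scores reordered by the surviving beams,
# right after beam_scorer.process has produced beam_idx
if return_dict_in_generate and output_scores:
    scores += (next_token_scores_processed[beam_idx, :],)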

I think that @rbbb's ultimate source of confusion is due to the fact that beam_scores_1 != step_2_scores, hence the remark: "This will give you proper scores for the beams that generated the output, not the beams that were present when the nth token was generated."

@patrickvonplaten what do you think?

rbbb commented 2 years ago

Yes, it depends on what you want in those scores. I think currently it's the score for the n beams that were present in the search at the time that token in the sequence was chosen (even if a beam gets overwritten).

For most applications, I think people would want the probability of each token at each step of sequence creation (and ultimately, perplexity). But in the end you'd have to ask the people who use the library.
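
Concretely, once you have the per-token probabilities of the chosen tokens, perplexity is just the exponential of the mean negative log probability; a minimal sketch (token_logprobs is a hypothetical 1-D tensor of log p(token_t | prefix), one entry per generated token):

import torch

#minimal sketch: perplexity from hypothetical per-token log probabilities
token_logprobs = torch.log(torch.tensor([0.99, 0.13, 0.69]))  #example values only
perplexity = torch.exp(-token_logprobs.mean())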

patrickvonplaten commented 2 years ago

Hey @rbbb,

We've had quite some discussion about what exactly we should save in the scores list. Could you maybe take a look at this PR: https://github.com/huggingface/transformers/pull/14654 and all the linked issues to understand why beam_scores are implemented the way they are?

As you can see in the PR, we don't save the conditional beam scores anymore but the processed per-token probabilities, so it should be rather straightforward to compute the perplexity.
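
A rough sketch of that computation, assuming the output exposes beam_indices and that out.scores[t] holds the processed per-token scores for each beam at step t (names as in the notebook above; the exact output layout may differ between versions):

import torch
import torch.nn.functional as F

# rough sketch: per-token log probabilities of the tokens actually generated on the best beam,
# assuming out.beam_indices[0][t] indexes (into batch_size * num_beams) the beam used at step t
gen_tokens = out.sequences[0, 1:]  # drop the decoder start token (encoder-decoder case)
token_logprobs = []
for t, step_scores in enumerate(out.scores):
    beam = out.beam_indices[0][t]
    if beam < 0:  # -1 padding after the sequence has finished
        break
    logprobs = F.log_softmax(step_scores[beam], dim=-1)  # re-normalize; harmless if already log-probs
    token_logprobs.append(logprobs[gen_tokens[t]])
token_logprobs = torch.stack(token_logprobs)
perplexity = torch.exp(-token_logprobs.mean())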

rbbb commented 2 years ago

@patrickvonplaten Looking at the discussion on the linked discuss page, this is exactly: "The problem now is that the scores don’t really give any information about the probability of token j at time i which is what most people seem to be interested in." So this issue is just another way of saying "I'd like to have the token probabilities (forward, if it's autoregressive) for a given beam", which is exactly #14612, #14086, #14065 and the request at the end of #10012.

Also, I think @cwkeam is right: my patch confuses the scores before generating the n_beams*2 next tokens with the scores after. It does give legible output about which tokens are generated with which probability (it might not be 100% correct, though).

I guess I just expected to be able to do tokenizer.batch_decode(torch.topk(outscores[0,:,:], k=5)[1]) to see the top 5 tokens that could be generated at each step on the beam that generated the first phrase, plus an additional softmax for the probabilities...
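
Something in that spirit, as a sketch, assuming the per-step scores have already been realigned to the winning beam (aligned_scores of shape (num_steps, vocab_size), as in the realignment sketch in my first comment):

#sketch: top-5 candidate tokens (and probabilities) at each step of the winning beam
top_probas, top_ids = nn.functional.softmax(aligned_scores, dim=-1).topk(5, dim=-1)
print(tokenizer.batch_decode(top_ids))  #one string of 5 candidate tokens per step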

Closing this issue for now as it is essentially the same as the previous issues mentioned above.