huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

BART model generates drastically worse output when labels are passed to forward method #19225

Closed AndreaSottana closed 2 years ago

AndreaSottana commented 2 years ago

Who can help?

@patrickvonplaten @patil-suraj

Reproduction

According to the BART documentation, passing labels to the forward pass should also return the loss in the output Seq2SeqLMOutput object. However, it also seems to have the very bad side effect of drastically degrading the output, and I cannot understand why. I've used the example below, taken from the CNN/DailyMail dataset.

data = {"input": """March 14 is my favorite day to be a nerd. Across the country, math geeks in museums, schools, private groups and elsewhere gather to celebrate the number pi, approximately 3.14. That's why March 14 -- 3-14 -- is Pi Day. What's more, Albert Einstein was born on this day. A quick refresher: Pi is defined as the distance around a perfect circle, or the circumference, divided by the distance across it, or the diameter. It is also involved in calculating the area of a circle, the volume of a sphere, and many other mathematical formulas you might need in the sciences. Throughout history, people have been captivated by this number because there is no way to calculate it exactly by a simple division on your calculator. What's more, its digits go on infinitely, without any pattern in the numbers. 3.1415926535897932 ... etc. Even that many digits are more than most people would need for everyday use, but some folks have been inspired to memorize thousands of digits of pi, or even use the digits to create poetry or music. On Pi Day, one number 'reeks of mystery'. Math may be scary, but pi is not -- as evidenced by the widespread revelry on Pi Day. One might even say -- gasp! -- it's cool to like pi these days. Even the House of Representatives supported the designation of March 14 as National Pi Day in 2009. In countries where the day is written before the month, Friday is 14-3, which looks less like pi. "And so Pi Day is an acquired taste," mathematician Jonathan Borwein, at the University of Newcastle in Australia, said in an e-mail. Conveniently, "pi" sounds like "pie," and pies are round. You could celebrate Pi Day in a casual way by grabbing a slice of pastry, or pizza. If you're in enrolled in school, your math class or math department might be doing something special already. But if you happen to live in a particularly pi-happy place, you might be able to take part in some larger-scale, pi-inspired activities. Where Pi Day began. 
If you want to go where the day is said to be "invented," look no further than San Francisco's Exploratorium. Larry Shaw, who worked in the electronics group at the museum, began the tradition in 1988. Last year was Pi Day's 25th anniversary there. Pi Day began as a small gathering with mostly museum staff. Now it's a public pi extravaganza featuring a "Pi procession," whose attendees get a number -- 0 to 9 -- and line up in the order of pi's digits: 3.14159265 ... you get the idea. The parade ends at the "pi shrine" -- a pi symbol with digits spiraling around it embedded in the sidewalk, which was unveiled last year. For those who can't attend in person, the Exploratorium has a Second Life Pi Day event that includes "irrational exhibits, fireworks, cheerleaders, music, and dancing." The museum also lists a bunch of educational activities to teach about the concept of pi. On Pi Day, is 'pi' under attack? Where Einstein lived. On the opposite coast, the leafy university town where Albert Einstein spent the last 22 years of his life is showing community-wide exuberance for pi. Princeton, New Jersey, kicks off Pi Day weekend on Thursday night with a reading by physicist Charles Adler, then heads into a full day of activities on Friday, including a walking tour of Einstein's neighborhood and a pizza pie-making contest. The pie-eating contest takes place at McCaffrey's supermarket, while an Einstein look-alike competition will match mustaches and wild gray hair at the Princeton Public Library. Pi fans who have been spending the last year memorizing digits can show off and compete at the library, where the winner among 7- to 13-year-olds can take home a cool pi-hundred (That is, $314.15). The Historical Society of Princeton will have an Einstein birthday party. Tetsuya Miyamoto, inventor of the KENKEN puzzle, will speak at the library as well. """,
"output": "March 14 is my favorite day to be a nerd, because in museums, schools, private groups and elsewhere peoplegather to celebrate the number pi, approximately 3.14"}

Using the example above, run the following code:

import pandas as pd
import torch
from datasets import Dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from torch.utils.data import DataLoader
train_dataset = Dataset.from_pandas(pd.DataFrame(data=[data]))

tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-cnn')
def preprocess_function(dataset):
    model_inputs = tokenizer(
        dataset["input"], max_length=1024, truncation=True, padding='do_not_pad'
    )
    # Set up the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            dataset["output"], max_length=1024, truncation=True, padding='do_not_pad'
        )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_all_data = train_dataset.map(preprocess_function, batched=True)
tokenized_all_data = tokenized_all_data.remove_columns(['input', 'output'])

train_dataloader = DataLoader(tokenized_all_data)
for sample in train_dataloader:
    input_ids = torch.tensor(sample['input_ids'])
    attention_mask = torch.tensor(sample['attention_mask'])
    labels = torch.tensor(sample['labels'])
    # Pad the labels to 1024 tokens with -100, the label value ignored by the loss
    labels = torch.cat((labels, -100 * torch.ones(1024 - labels.shape[0], dtype=torch.long)), dim=0)

model = AutoModelForSeq2SeqLM.from_pretrained('facebook/bart-large-cnn')
model.eval()
with torch.no_grad():
    output_1 = model(input_ids=input_ids.unsqueeze(0), attention_mask=attention_mask.unsqueeze(0), labels=labels.unsqueeze(0))
    output_2 = model(input_ids=input_ids.unsqueeze(0), attention_mask=attention_mask.unsqueeze(0))
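For reference, the loss BART returns when labels are passed is PyTorch's `CrossEntropyLoss` with its default `ignore_index=-100`, averaged over the label positions that are not -100; that is why the repro pads the labels with -100. A minimal plain-Python sketch of that masked average, using toy logits and a hypothetical `masked_cross_entropy` helper (not a transformers function):

```python
import math

IGNORE_INDEX = -100  # label value skipped by the loss (CrossEntropyLoss default)


def log_softmax(row):
    """Numerically stable log-softmax over one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def masked_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean token-level cross-entropy, skipping positions labelled ignore_index.

    logits: one list of vocabulary scores per decoder position.
    labels: one target token id per decoder position.
    """
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue  # padded position: contributes nothing to the loss
        total -= log_softmax(row)[label]
        count += 1
    # Like PyTorch's mean reduction, return NaN if every position is ignored
    return total / count if count else float("nan")


# Toy example: 3-token vocabulary, 2 real label positions
logits = [[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]]
labels = [0, 1]
loss = masked_cross_entropy(logits, labels)

# Extra positions labelled -100 do not change the loss, whatever their logits are
padded_logits = logits + [[9.0, -9.0, 0.0]] * 3
padded_labels = labels + [IGNORE_INDEX] * 3
```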

Now check the output text (greedy decoding is assumed here just for quick visualisation, although better options such as beam search could of course be used):

Output 1 (labels passed):

MarchMarch 14 is Pi favorite day to be a nerd. says it math, schools, private groups and elsewhere, celebrateather to celebrate pi number pi. approximately 3.14.......gngngngggneralgngngngngngnggggggngngngngngnigeng.igenigenigengi.igenigenigenigenigenigenigenigenigenbergbergigenigenigengiangiangiangiangiangiangiangiangiangiangian688giangianigengiangiangiangiangiangian.giangiangiangianindaindaivalentindagiangianindaivalentificent,,,,indainda,..........,..,.,..........igenangan...angananganigen..angangigiigengiigengigigigiigengianigengiangiangiangiangiangiangiangiangiangiangiangianigiigigianigigiangigiangiangianigiigiigiigiivalentivalentivalentgiangiangiangiangiangiangiangiangianivalentivalentivalentivalentgiangiangiangiangiangiangiangiangianivalentgianivalentivalentivalentivalentivalentivalentgiangiangiangiangiangiangianivalentuperuperivalentgiangiangiangiangiangiangiangian,giangiangiangianuperiberigenigen , igenuperuper uperuperuperuperuperuperuperuperuperuperuperuperuperuper uper uper uperuperuperuperuperiberiberiberiberuperiberiberiberiberiberiberiberiberiberiberiberiberiber..gigi.........iber.giiberiberiberiberiberivalentiberiberiberiberiberbergiberiberbergbergbergbergbergbergbergbergbergbergbergbergiberiberuperuperuperiberiberiberuperuperuperuperiberuperuperiberindauperindaindaindainda,,,,,,,,,,,,.,,,,,,,,,,,,,,,,,,,umberumber,umberumber,,,,,gigiumber,umberuperuper,,,,,..,,giigg,............................,,,,,,,,,,.,,,,,.,,,,,,,,,,,,,,,,,,,gigigigigi.</s>.......................................................itableitable..........itableitable...gi.....trusttrusttrusttrust.........trust.....squsqu..................squ.................................................................gi...............................,......................................................................................................................................................................................................................................................... 
and...............</s> and. and and and and and and and and and..- and and and and and and and and and and and and and...

Output 2 (labels not passed):

MarchMarch 14 is Pi favorite day to be a nerd. Pi the country, math geeks gather museums, schools, private groups and elsewhere gather to celebrate the number pi. approximately 3.14. In's why March 14 -- 3-14 -- is Pi Day.</s>'s more, Albert Einstein was born on this day.</s> quick refresher: Pi is defined as the distance around a perfect circle, or the circumference. divided by the distance across it. or the diameter.</s> is also involved in calculating the area of a circle, the volume of a sphere, and many other mathematical formulas. might need in the sciences.</s> history, people have been captivated by this number because there is no way to calculate it exactly by a simple division on your calculator.</s>'s more, its digits go on infinitely, without any pattern in the numbers.</s>.1415926535897932... etc.</s> that many digits are more than most people would need for everyday use. but some folks have been inspired to memorize thousands of digits of pi. or even use the digits to create poetry or music.</s> Pi Day, " might'reeks of mystery', Math may be scary, but pi is not -- as evidenced by the widespread revelry on Pi Day.</s> might even say -- gasp! 
-- it's cool to like pi these days.</s> the House of Representatives supported the designation of March 14 as National Pi Day in 2009.</s> countries where the day is written before the month, Friday is 14-3, which looks less like pi.</s>pi so Pi Day is an acquired taste," mathematician Jonathan Borwein, at the University of Newcastle in Australia, said in an e-mail.</s>veniently, "pi" sounds like "pie," and pies are round.</s> could celebrate Pi Day in a casual way by grabbing a slice of pastry, or pizza.</s> you're in enrolled in school, your math class or math department might be doing something special already.</s> if you happen to live in a particularly pi-happy place, you might be able to take part in some larger-scale, pi-inspired activities.</s> Pi Day is.</s> you want to go where the day is said to be "invented," look no further than San Francisco's Exploratorium.</s> Shaw, who worked in the electronics group at the museum, began the tradition in 1988.</s> year was Pi Day's 25th anniversary there.</s> Day began as a small gathering with mostly museum staff.</s> it's a public pi extravaganza featuring a "Pi procession," whose attendees get a number -- 0 to 9 -- and line up in the order of pi's digits: 3.14159265... you get the idea.</s> parade ends at the "pi shrine" -- a pi symbol with digits spiraling around it embedded in the sidewalk, which was unveiled last year.</s> those who can't attend in person, the Exploratorium has a Second Life Pi Day event that includes "irrational exhibits, fireworks, cheerleaders, music, and dancing"</s> winner also lists a bunch of educational activities to teach about the concept of pi.</s> Pi Day, the 'pi' under attack?</s> Einstein spent.</s> the opposite coast, the leafy university town where Albert Einstein spent the last 22 years of his life is showing community-wide exuberance for pi.</s>, New Jersey, kicks off Pi Day weekend on Thursday night with a reading by physicist Charles Adler. 
then heads into a full day of activities on Friday. including a pizza tour of Einstein's neighborhood. a pizza pie-making contest.</s> winner-eating contest takes place at McCaffrey's supermarket. while an Einstein look-alike competition will match mustaches and wild gray hair at the Princeton Public Library.</s> fans who have been spending the last year memorizing digits can show off and compete at the library, where the winner among 7- to 13-year-olds can take home a cool pi-hundred (That is, $314.15).</s> Historical Society of Princeton will have an Einstein birthday party.</s>etsuya Miyamoto, inventor of the KENKEN puzzle, will speak at the library, well.</s>
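The issue does not show how the logits were turned into text. The following is a minimal sketch of the greedy reading, assuming it means taking the argmax of the logits at each decoder position (plain Python, toy numbers, hypothetical `greedy_tokens` helper). A single forward pass returns one row of logits per decoder input position, so this reading yields exactly one token per position. The text is therefore as long as the decoder input: the 1024-long padded labels for Output 1 and the shifted article for Output 2. It is not a sequence generated token by token, which is what `model.generate()` produces.

```python
def greedy_tokens(logits):
    """Pick the highest-scoring token id at every decoder position (argmax).

    This reads one token per position of the decoder input that was fed in;
    it is not autoregressive generation, which feeds each chosen token back in.
    """
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


# Toy logits for 4 decoder positions over a 3-token vocabulary
toy_logits = [
    [0.1, 2.0, 0.3],
    [1.5, 0.2, 0.1],
    [0.0, 0.0, 3.0],
    [0.4, 0.9, 0.2],
]
ids = greedy_tokens(toy_logits)
print(ids)  # [1, 0, 2, 1]
```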

Expected behavior

Output 1 and Output 2 above are completely different, whereas I would expect them to be the same; Output 1 is far worse than Output 2. Why does passing or not passing the labels (see the snippet below, taken from the code above)

with torch.no_grad():
    output_1 = model(input_ids=input_ids.unsqueeze(0), attention_mask=attention_mask.unsqueeze(0), labels=labels.unsqueeze(0))
    output_2 = model(input_ids=input_ids.unsqueeze(0), attention_mask=attention_mask.unsqueeze(0))

cause such a huge difference in the generated logits, when I would expect the only difference to be that the loss is returned when the labels are passed?

On a side note, Output 2 also contains a large number of end-of-sequence tokens (</s>). Are we supposed to ignore everything after the first </s> has been generated?
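When a sequence is produced autoregressively, decoding conventionally stops at the first end-of-sequence token. This is a minimal sketch of truncating a list of token ids there, using plain Python, made-up ids, and a hypothetical `truncate_at_eos` helper. The `start` argument covers a sequence that begins with a decoder start token sharing the EOS id. That is an assumption here, though BART-style configs can be set up that way. With `start`, the search begins after that token.

```python
def truncate_at_eos(token_ids, eos_token_id, start=0):
    """Keep tokens from `start` up to (not including) the first EOS at or after `start`.

    If no EOS appears after `start`, everything from `start` onward is kept.
    """
    tail = token_ids[start:]
    if eos_token_id in tail:
        return tail[: tail.index(eos_token_id)]
    return tail


# Made-up ids: 2 plays the role of </s>; the sequence opens with a start token equal to EOS
print(truncate_at_eos([2, 0, 5, 6, 2, 7, 2], eos_token_id=2, start=1))  # [0, 5, 6]
```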

patrickvonplaten commented 2 years ago

@ArthurZucker could you take a look here?

AndreaSottana commented 2 years ago

Hi @ArthurZucker, did you manage to take a look by any chance? I did some debugging on the issue above and found what may be a contributing cause. As mentioned in this code comment, unlike other seq2seq models, BART automatically creates decoder_input_ids from input_ids if no decoder_input_ids are provided. However, if labels are provided, the decoder_input_ids are built from the labels instead (see these lines of code). As far as I can see, this happens regardless of whether the model is in train or eval mode, whereas in eval mode I would expect the labels not to be fed into the decoder, since only the trained model and the input should be used. This would explain why the outputs differ depending on whether I pass the labels to the forward pass: my expectation was that passing the labels in eval mode would only add the loss to the Seq2SeqLMOutput object, without affecting the logits.
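The derivation described above can be sketched in plain Python for a single sequence. It mirrors what `shift_tokens_right` in transformers' BART code does to a batch tensor: it prepends the decoder start token, drops the last position, and replaces -100 with the pad id. The name `shift_right` and the ids below are illustrative only.

```python
def shift_right(labels, pad_token_id, decoder_start_token_id):
    """Plain-Python sketch of building decoder inputs from one label sequence.

    Prepend the decoder start token, drop the last label, and replace any
    -100 (ignored label positions) with the pad token id.
    """
    shifted = [decoder_start_token_id] + labels[:-1]
    return [pad_token_id if t == -100 else t for t in shifted]


# Illustrative ids: a short "summary" followed by two -100 padding positions
labels = [0, 10, 11, 12, 2, -100, -100]
print(shift_right(labels, pad_token_id=1, decoder_start_token_id=2))
# [2, 0, 10, 11, 12, 2, 1]
```

Applied to the repro's labels, which were padded to 1024 with -100, this gives a 1024-position decoder input made of the summary followed by pad tokens. Without labels, the same shift is applied to `input_ids`, so the decoder is fed the article itself.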

ArthurZucker commented 2 years ago

Hey! Sorry, I'm going to have a look tomorrow. Very nice debugging 🤗 I'm pretty sure we would want to have the correct outputs along with the loss.

Computing the loss at evaluation time is not really important, but it is informative. Is that why you are trying to do it?

AndreaSottana commented 2 years ago

Hi @ArthurZucker, I'm calculating the evaluation loss because I'm implementing a loss-based curriculum learning strategy (re-adapted from an idea introduced in this paper). Essentially, I need to order the training samples (not the validation samples) by their loss, used as a proxy for difficulty (on the assumption that samples with higher loss are harder for the model), so as to build an ordered curriculum. Even though I'm calculating the loss on the training dataset, I need to do it in eval mode. I don't want the model to learn or backpropagate during this process; otherwise the order in which I feed the samples would affect the loss calculation, because the model would learn from earlier samples. That would skew the difficulty scoring, which should be based on the loss of each training sample at a fixed snapshot of the model. This is not part of the model training: it's done beforehand (and potentially at regular intervals during training) to rank the samples by difficulty. That's why I need eval mode, and how I stumbled upon the issue above.
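The ranking step described above can be sketched as follows. The sketch assumes a per-sample `loss_fn` that scores one training example with the frozen model; in practice this would be a forward pass with `labels` in eval mode under `torch.no_grad()`. `order_by_difficulty` is a hypothetical helper, and lower loss is treated as easier, following the assumption stated above.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def order_by_difficulty(samples: Sequence[T], loss_fn: Callable[[T], float]) -> List[T]:
    """Rank samples easiest-first, using each sample's loss as a difficulty proxy.

    Every sample is scored first, then sorted. loss_fn is assumed not to update
    the model, so all scores come from the same model snapshot.
    """
    scored = [(loss_fn(s), i) for i, s in enumerate(samples)]
    scored.sort()  # ascending loss; ties keep their original order
    return [samples[i] for _, i in scored]


# Toy losses standing in for per-sample model losses
toy_losses = {"a": 2.0, "b": 0.5, "c": 1.0}
print(order_by_difficulty(["a", "b", "c"], toy_losses.__getitem__))  # ['b', 'c', 'a']
```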

ArthurZucker commented 2 years ago

Ok! I got it.

You were also wondering about the </s> tokens: you can skip them when decoding with tokenizer.decode(..., skip_special_tokens=True).
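To illustrate what `skip_special_tokens=True` does to the ids before they become text, here is a plain-Python sketch that uses a hypothetical `drop_special_ids` helper and made-up ids. Special tokens are removed wherever they occur, but the text after the first </s> is kept. So this is not the same as stopping at the first end-of-sequence token.

```python
def drop_special_ids(token_ids, special_ids):
    """Remove every special-token id wherever it appears (no truncation)."""
    special = set(special_ids)
    return [t for t in token_ids if t not in special]


# Made-up ids: 0, 1 and 2 play the roles of <s>, <pad> and </s>
print(drop_special_ids([2, 0, 5, 6, 2, 7, 2], special_ids={0, 1, 2}))  # [5, 6, 7]
```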

About your issue: I checked, and I am pretty sure we want to change the current behavior, but the change might not be backward compatible. One possible update would be to stop deriving the decoder_input_ids from the labels. I don't really know where that behavior comes from, but if the results are clearly better (as they are now), we might want to do this. Let me check and come back to you!

In the meantime, I found a quick fix:

from transformers.models.bart.modeling_bart import shift_tokens_right

train_dataloader = DataLoader(tokenized_all_data)
for sample in train_dataloader:
    input_ids = torch.tensor(sample['input_ids'])
    attention_mask = torch.tensor(sample['attention_mask'])
    labels = torch.tensor(sample['labels'])
    # Pad the labels with -100 (ignored by the loss) up to the input length
    labels = torch.cat((labels, -100 * torch.ones(input_ids.shape[0] - labels.shape[0], dtype=torch.long)), dim=0)

model = AutoModelForSeq2SeqLM.from_pretrained('facebook/bart-large-cnn')
model.eval()
with torch.no_grad():
    # Pass decoder_input_ids explicitly so the model does not derive them from the labels
    output_1 = model(
        input_ids=input_ids.unsqueeze(0),
        attention_mask=attention_mask.unsqueeze(0),
        labels=labels.unsqueeze(0),
        decoder_input_ids=shift_tokens_right(input_ids.unsqueeze(0), model.config.pad_token_id, model.config.decoder_start_token_id),
    )
    output_2 = model(input_ids=input_ids.unsqueeze(0), attention_mask=attention_mask.unsqueeze(0))

Tell me if that works for you?!

AndreaSottana commented 2 years ago

Hi @ArthurZucker, thanks so much for providing an interim solution. I can confirm your fix works and returns the correct loss I was looking for. Glad we uncovered this, and thanks also for confirming how to skip the </s> tokens. Feel free to close the issue or keep it open for future work at your discretion; in the meantime I will use the solution you provided.

ArthurZucker commented 2 years ago

Closing this, as it makes more sense to keep the current behavior.