huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

repetition_penalty not being applied #29080

Closed adenhaus closed 7 months ago

adenhaus commented 8 months ago

System Info

Who can help?

@gante @ArthurZucker @younesbelkada

Information

Tasks

Reproduction

Download a model and run inference. Change repetition_penalty. You will see that the output does not change.

Expected behavior

The output should change when repetition_penalty is changed.

borisdayma commented 8 months ago

I just noticed the same issue. I think it would be a useful feature.

gante commented 8 months ago

repetition_penalty is working on my end:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

inputs = tokenizer(["The quick brown"], return_tensors="pt")
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=100)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
gen_out_repetition_penalty = model.generate(**inputs, do_sample=False, max_new_tokens=100, repetition_penalty=1.5)
decoded_repetition_penalty = tokenizer.batch_decode(gen_out_repetition_penalty, skip_special_tokens=True)
print(decoded == decoded_repetition_penalty)
# False

If you're still seeing this issue, I will need a reproducer to figure out what's wrong 🤗

borisdayma commented 8 months ago

It's related to flax models only. Not sure if that's what @adenhaus meant.

adenhaus commented 8 months ago

@borisdayma no I noticed it on an mt5 finetuned model, not a flax model. Happy to provide code to reproduce but the day after I posted this it was working again.

adenhaus commented 8 months ago

@gante i set repetition_penalty in the generation_config.json on the hub, didn't pass it to model.generate. Not sure if that makes a difference

borisdayma commented 8 months ago

Then maybe I should create a separate issue for flax models which don't seem to support this option.

gante commented 8 months ago

@borisdayma correct, the option is not supported on flax 🤗 (if you open the issue, please mention that unsupported flags on a given framework should raise a warning, we are not raising them atm)

gante commented 8 months ago

@adenhaus

i set repetition_penalty in the generation_config.json on the hub, didn't pass it to model.generate. Not sure if that makes a difference

This should be fine 🤗

I noticed it on an mt5 finetuned model

This may be the cause! mt5 is an encoder-decoder model. Have you tried the encoder_repetition_penalty flag instead? Or maybe both encoder_repetition_penalty and repetition_penalty, depending on your use case (the former acts on the input text of mt5, the latter acts on the generated text).

adenhaus commented 8 months ago

@gante I want it to act on the generated text. And now I am noticing the issue again. I change the repetition_penalty in the generation_config file on the hub from 1.0 to 1.9 but see no difference in the outputs.

Here are steps to reproduce:

I am using this model from the hub.

With this code:

import torch
import transformers
import pandas as pd

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)

# Model
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = 'adenhaus/mt5-small-eng-tata-blueprints'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

def split_verbalisation(text):
  split_text = text.split("Verbalisation: ")

  if len(split_text) > 1:
    return split_text[1]
  else:
    return text

# Predict function
def generate_verbalisation(model, tokenizer, example):
    input_ids = tokenizer(example)["input_ids"]
    input_ids = torch.LongTensor(input_ids).view(1, -1).to(device)
    generated_ids = model.generate(input_ids, max_new_tokens=200)
    prediction = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    prediction = split_verbalisation(prediction)

    print(prediction)
    return prediction

# Load test set
df = pd.read_csv("/csv_path", sep='\t')

# Generate output dataset for evaluation
out_df = pd.DataFrame(columns=['preds', 'refs'])
out_df['refs'] = df['target']
out_df['preds'] = df['linearized_input'].apply(lambda x: generate_verbalisation(model, tokenizer, x))
out_df['linearized_input'] = df['linearized_input']

out_df.to_csv('small-eng-blueprints-preds.csv', sep='\t', index=False)

And here is a sample of the (tab separated) csv I'm loading:

linearized_input    target
Median ages at first sex, first marriage, and birth of first child among men age 30-34 by residence | Age (years) | (First sex, Urban, 18.4) (First marriage, Urban, 25.6) (Birth of first child, Urban, 26.9) (First sex, Rural, 18.5) (First marriage, Rural, 24.1) (Birth of first child, Rural, 25.1) (First sex, Urban, 21.3) (First marriage, Urban, 27.7) (Birth of first child, Urban, 28.2) (First sex, Rural, 20.8) (First marriage, Rural, 25) (Birth of first child, Rural, 25.6) (First sex, Urban, 21.5) (First marriage, Urban, 27.4) (Birth of first child, Urban, 29.2) (First sex, Rural, 22.1) (First marriage, Rural, 25.5) (Birth of first child, Rural, 26.8) (First sex, Urban, 21.7) (First marriage, Urban, 27) (Birth of first child, Urban, 29.4) (First sex, Rural, 21) (First marriage, Rural, 22.8) (Birth of first child, Rural, 24.9) (First sex, Urban, 22.6) (First marriage, Urban, 27.7) (Birth of first child, Urban, 27.8) (First sex, Rural, 22.5) (First marriage, Rural, 25.1) (Birth of first child, Rural, 26) (First sex, Urban, 18.6) (First marriage, Urban, 25.5) (Birth of first child, Urban, 26.4) (First sex, Rural, 18.4) (First marriage, Rural, 22.5) (Birth of first child, Rural, 23.6) (First sex, Urban, 25.7) (First marriage, Urban, 26.1) (Birth of first child, Urban, 28.3) (First sex, Rural, 23.8) (First marriage, Rural, 23.8) (Birth of first child, Rural, 26.2) (First sex, Urban, 20.9) (First marriage, Urban, 22.2) (Birth of first child, Urban, 24.5) (First sex, Rural, 20.1) (First marriage, Rural, 20.7) (Birth of first child, Rural, 23.7) A number of the differences between rural and urban areas are common across most of the countries.
Family formation trajectories among men age 30-34, Benin | (Timing of first sex, Earlier, 0.32) (Timing of first marriage, Earlier, 0.29) (Timing of birth of first child, Earlier, 0.3) (Timing of first sex, Earlier, 278) (Timing of first marriage, Earlier, 256) (Timing of birth of first child, Earlier, 266) (Timing of first sex, Typical, 0.5) (Timing of first marriage, Typical, 0.38) (Timing of birth of first child, Typical, 0.39) (Timing of first sex, Typical, 443) (Timing of first marriage, Typical, 335) (Timing of birth of first child, Typical, 344) (Timing of first sex, Later, 0.18) (Timing of first marriage, Later, 0.33) (Timing of birth of first child, Later, 0.31) (Timing of first sex, Later, 160) (Timing of first marriage, Later, 290) (Timing of birth of first child, Later, 271)   In the first Sankey diagram for Benin (Figure 6), 32% of men experienced first sexual intercourse earlier-than-typical, 50% at typical timing, and 18% later-than-typical.
Family formation trajectories among men age 30-34, Benin | (Timing of first sex, Earlier, 0.32) (To Earlier, Earlier, 0.12) (To Typical, Earlier, 0.1) (To Later, Earlier, 0.1) (Timing of first marriage, Earlier, 0.29) (To Earlier, Earlier, 0.09) (Timing of birth of first child, Earlier, 0.3) (Timing of first sex, Earlier, 278) (Timing of first marriage, Earlier, 256) (Timing of birth of first child, Earlier, 266) (Timing of first sex, Typical, 0.5) (To Earlier, Typical, 0.15) (To Typical, Typical, 0.21) (To Later, Typical, 0.15) (Timing of first marriage, Typical, 0.38) (To Earlier, Typical, 0.11) (To Typical, Typical, 0.15) (To Later, Typical, 0.11) (Timing of birth of first child, Typical, 0.39) (Timing of first sex, Typical, 443) (Timing of first marriage, Typical, 335) (Timing of birth of first child, Typical, 344) (Timing of first sex, Later, 0.18) (To Earlier, Later, 0.02) (To Typical, Later, 0.07) (To Later, Later, 0.09) (Timing of first marriage, Later, 0.33) (To Typical, Later, 0.06) (To Later, Later, 0.08) (Timing of birth of first child, Later, 0.31) (Timing of first sex, Later, 160) (Timing of first marriage, Later, 290) (Timing of birth of first child, Later, 271) In the first Sankey diagram for Benin (Figure 6), 32% of men experienced first sexual intercourse earlier-than-typical, 50% at typical timing, and 18% later-than-typical.
Youth empowerment (pooled terciles) among women age 15-29 by country | (High, Mali, 13.2) (Medium, Mali, 17.9) (Low, Mali, 68.9) (High, Ethiopia, 14.7) (Medium, Ethiopia, 22.7) (Low, Ethiopia, 62.6) (High, Malawi, 15.8) (Medium, Malawi, 55) (Low, Malawi, 29.2) (High, Uganda, 25.4) (Medium, Uganda, 32.1) (Low, Uganda, 42.4) (High, Zambia, 27.4) (Medium, Zambia, 28.3) (Low, Zambia, 44.3) (High, Nigeria, 34.2) (Medium, Nigeria, 35.5) (Low, Nigeria, 30.3) (High, Haiti, 42.8) (Medium, Haiti, 40.1) (Low, Haiti, 17.1) (High, Nepal, 43.5) (Medium, Nepal, 37) (Low, Nepal, 19.5) (High, Senegal, 43.9) (Medium, Senegal, 21.5) (Low, Senegal, 34.6) (High, Philippines, 80.6) (Medium, Philippines, 13.7) (Low, Philippines, 5.7)    A mere 13% of young women are in the high empowerment tercile in Mali.
Ownership of House and Land | Percent of women and men age 15-49 who: | (Women, Own a home alone or jointly, 18) (Men, Own a home alone or jointly, 40) (Women, Own land alone or jointly, 15) (Men, Own land alone or jointly, 34) Only 18% of women own a house, either alone or jointly, and only 15% own land.
Attitudes toward Wife Beating | Percent of women and men age15-49 who believe that a husband is justified in beating his wife for the following reasons: | (Women, Wife burns the food, 14) (Men, Wife burns the food, 8) (Women, Wife argues with him, 21) (Men, Wife argues with him, 13) (Women, Wife goes out without telling him, 25) (Men, Wife goes out without telling him, 13) (Women, Wife neglects children, 25) (Men, Wife neglects children, 14) (Women, Wife refuses to have sex with him, 19) (Men, Wife refuses to have sex with him, 11) (Women, Any of these, 35) (Men, Any of these, 25) More than one-third of women (35%) and one-quarter of men agree that a husband is justified in beating his wife for at least one of these reasons: if she burns the food, argues with him, goes out without telling him, neglects the children, or refuses to have sex with him.
Employment Characteristics among Working Women | Percent | (All year, 62) (Seasona, 32) (Occasional, 6) (Employed by family member, 9) (Employed by non-family member, 30) (Self-employed, 62) (Cash only, 65) (Cash and in-kind, 11) (In-kind only, 2) (Not paid, 23)  the percent distribution of employed women age 15-49 by type of earnings and employer characteristics, according to type of employment (agricultural or non-agricultural).
Trends in Total Fertility Rate, Kenya 1975-2008* | Total Fertility Rate | (1975-78, 8.1) (1984-88, 6.7) (1990-92, 5.4) (1995-97, 4.7) (2000-02-01 00:00:00, 4.9) (2006-08-01 00:00:00, 4.6) The data indicate that the TFR declined during the 1980s and 1990s, changing from a high of 8.1 children per woman in the late 1970s to 6.7 in the late 1980s, and dropping to 4.7 during the last half of the 1990s.
Trends in Contraceptive Use, Kenya 1978-2008 (percentage of currently married women using any method) | Percent | (1978, 7) (1984, 17) (1989, 27) (1993, 33) (1998, 39) (2003, 39) (2008-09-01 00:00:00, 46)    there has been a substantial increase in contraceptive use since the late 1970s, from 7 percent of married women in 1978 to 46 percent in 2008-09.
Percentage of Currently Married Women Whose Husbands Have At Least One Other Wife | Percent | (Urban, 7) (Rural, 15) (Nairobi, 2) (Central, 3) (Coast, 15) (Eastern, 5) (Nyanza, 21) (Rift Valley, 15) (Western, 23) (North Eastern, 36) (No education, 33) (Primary incomplete, 17) (Primary complete, 8) (Secondary+, 8)  Thirteen percent of currently married women live in polygynous unions (i.e., they have one or more co-wives).
Planning Status of Births | Percent | (Wanted then, 0.57) (Unwanted, 0.17) (Wanted later, 0.26) 17 percent of births in Kenya are unwanted, and 26 percent are mistimed (wanted later).
Trends in Receipt of Antenatal Care from a Skilled Medical Provider, Kenya 2003-2008 | Percentage of women with live birth in the past 5 years | (2003, 88) (2008-09-01 00:00:00, 92)   The 2008-09 data indicate a rise since 2003 in medical antenatal care coverage
Percentage of Children Age 12-23 Months with Specific Vaccinations | Percent | (BCG, 96) (DPT 1 - HepB - Hib, 96) (DPT 2 - HepB - Hib, 93) (DPT 3 - HepB - Hib, 86) (Polio 1, 96) (Polio 2, 94) (Polio 3, 88) (Measles, 85) (All basic vaccinations, 77) (No vaccinations, 3)   77 percent of children age 12-23 months are fully vaccinated at any time before the survey.
Infant and Young Child Feeding (IYCF) Practices | Percent | (Breastfed, 44) (Non-breastfed, 16) (Total, 39) Breastfed children are much more likely to be fed in accordance with IYCF practices than non-breastfed children
Number of Decisions in Which Women Participate | Percent of women | (Number of decisions, Percent of women) (0, 3) (1, 6) (2, 9) (3, 13) (4, 19) (5, 50)    the distribution of currently married women according to the number of decisions in which they participate.
Infant mortality rate in the 10 years preceding the survey by selected demographic characteristics | Deaths per 1,000 live births | (<20, 95) (20-29, 71) (30-39, 73) (40-49, 100) (1, 83) (2021-02-03 00:00:00, 65) (2021-04-06 00:00:00, 72) (7+, 103) (<2 years, 122) (2 years, 67) (3 years, 46) (4+ years, 45) (Male, 84) (Female, 70) This is true for all categories of mortality. With the exception of mothers in the 40-49 age group, infant mortality is higher for mothers under age 20 than for older mothers.
Mother's duration of stay in the health facility after giving birth | Percentage | (<6 hours, Vaginal birth, 29) (6-11 hours, Vaginal birth, 14) (12-23 hours, Vaginal birth, 6) (1-2 days, Vaginal birth, 41) (3+ days, Vaginal birth, 10) (<6 hours, Caesarean birth, 7) (6-11 hours, Caesarean birth, 1) (12-23 hours, Caesarean birth, 1) (1-2 days, Caesarean birth, 9) (3+ days, Caesarean birth, 81) the percent distribution of women who gave birth in a health facility in the five years preceding the survey by duration of stay in the facility and type of delivery.
Percentage of children age 12-23 months with specific vaccinations | Percentage vaccinated at any time before the survey | (BCG, 0.51) (DPT 1, 0.51) (DPT 2, 0.46) (DPT 3, 0.38) (Polio 0, 0.47) (Polio 1, 77) (Polio 2, 70) (Polio 3, 54) (Measles, 42) (All bascic vaccinations, 25) (No vaccinations, 21)    the percentage of children age 12-23 months who have received the various vaccinations by source of information (vaccination card or mother's report).
Trends in vaccination coverage among children age 12-23 months, 2003-2013 | Percentage | (2003 NDHS, BCG, 48) (2008 NDHS, BCG, 50) (2013 NDHS, BCG, 51) (2003 NDHS, DPT 3, 21) (2008 NDHS, DPT 3, 35) (2013 NDHS, DPT 3, 38) (2003 NDHS, Polio 3, 29) (2008 NDHS, Polio 3, 39) (2013 NDHS, Polio 3, 54) (2003 NDHS, Measles, 36) (2008 NDHS, Measles, 41) (2013 NDHS, Measles, 42) (2003 NDHS, All basic vaccinations, 13) (2008 NDHS, All basic vaccinations, 23) (2013 NDHS, All basic vaccinations, 25) (2003 NDHS, No vaccinations, 27) (2008 NDHS, No vaccinations, 29) (2013 NDHS, No vaccinations, 21)   The proportion of children who received none of the six basic vaccinations declined marginally by 6 percentage points from the 2003 level
IYCF indicators on breastfeeding status | Percentage of children | (Exclusive breastfeeding under age 6 months, 17) (Exclusive breastfeeding at age 4-5 months, 10) (Continued breastfeeding at 1 year, 84) (Introduction of solid, semisolid, or soft foods (6-8 months), 67) (Continued breastfeeding at 2 years, 35) (Age-appropriate breastfeeding (0-23 months), 52) (Predominant breastfeeding (0-5 months), 69) (Bottle feeding (0-23 months), 13)   the 2013 NDHS results for key IYCF breastfeeding practices among children under age 2 who are living with their mothers.
Ownership of, access to, and use of ITNs | Percent | (Percent of households with at least one ITN, 50) (Percent of households with at least one ITN for every two persons who stayed in the household the night before the interview, 22) (Percent of the household population with access to an ITN within their household, 36) (Percent of the household population who slept under an ITN, 13)   50 percent of households have at least one ITN.
Trends in age of first sexual intercourse | Percent | (2008 NDHS, Percentage of women 15-19 who had sexual intercourse before exact age 15, 15) (2013 NDHS, Percentage of women 15-19 who had sexual intercourse before exact age 15, 16) (2008 NDHS, Percentage of men 15-19 who had sexual intercourse before exact age 15, 6) (2013 NDHS, Percentage of men 15-19 who had sexual intercourse before exact age 15, 3) (2008 NDHS, Percentage of women 18-19 who had sexual intercourse before exact age 18, 53) (2013 NDHS, Percentage of women 18-19 who had sexual intercourse before exact age 18, 53) (2008 NDHS, Percentage of men 18-19 who had sexual intercourse before exact age 18, 26) (2013 NDHS, Percentage of men 18-19 who had sexual intercourse before exact age 18, 21) Among young women, there was practically no change in the proportion who had sexual intercourse before age 15 or age 18 in the five-year period between the two surveys.
Current use by method | (Not currently using) (Any traditional method) (Other modern method) (Female sterilization) (Injectable) (Pill) (IUD) (0.415) (0.016) (0.011) (0.012) (0.085) (0.16) (0.301)    The 2014 EDHS findings revealed that 59 percent of currently married women in Egypt are currently using a contraceptive method.
Treatment practices among children ill with diarrhea | Percent | (Health provider consulted) (Given oral rehydration salt/home solution) (Given antibiotic) (55) (30) (37)  Table 11.11 shows that, for the majority of children ill with diarrhea, feeding practices did not conform to the recommended practices.
adenhaus commented 8 months ago

@gante was having this issue on my cluster. Now not seeing the issue in a colab notebook. So weird. Maybe a caching thing?
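
One way to rule out a stale local copy is to check which repetition_penalty value the generation_config.json on disk actually contains. The snippet below is a minimal sketch under a few assumptions: the model files sit in a local directory, `read_repetition_penalty` is a hypothetical helper (not part of transformers), and the fallback of 1.0 mirrors the library default of "no penalty".

```python
import json
import tempfile
from pathlib import Path


def read_repetition_penalty(config_dir):
    """Return the repetition_penalty stored in <config_dir>/generation_config.json.

    Falls back to 1.0 (no penalty) when the file or the key is missing.
    Hypothetical helper for illustration only.
    """
    config_path = Path(config_dir) / "generation_config.json"
    if not config_path.is_file():
        return 1.0
    with config_path.open(encoding="utf-8") as handle:
        config = json.load(handle)
    return float(config.get("repetition_penalty", 1.0))


# Small demo with a throwaway directory standing in for a model snapshot.
with tempfile.TemporaryDirectory() as tmp_dir:
    print(read_repetition_penalty(tmp_dir))  # no file yet, so the fallback is used
    config_file = Path(tmp_dir) / "generation_config.json"
    config_file.write_text(json.dumps({"repetition_penalty": 1.9}), encoding="utf-8")
    print(read_repetition_penalty(tmp_dir))  # value read back from the file
```

On an already loaded model, printing `model.generation_config.repetition_penalty` is the equivalent check: if it still shows the old value after the file on the Hub was edited, the cluster is probably serving a cached copy.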

github-actions[bot] commented 7 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

chadj2 commented 2 months ago

I am testing with GPTBigCodeForCausalLM and I can confirm that repetition_penalty is not affecting the output. I noticed this because I was getting different output from vllm inference. The issue only happens when I use beam search.

The difference arises because vllm honors repetition_penalty while Huggingface does not when beam search is enabled. Here are my generation params:

{
  "early_stopping": true,
  "eos_token_id": 49155,
  "max_new_tokens": 512,
  "min_new_tokens": 16,
  "num_beams": 4,
  "output_scores": true,
  "return_dict_in_generate": true
}
gante commented 2 months ago

@chadj2 👋

Would you be able to share a more precise reproducer? (Which transformers/torch versions? Which hardware? Which model checkpoint? Were you checking at a logit level?)

We have several tests checking that flag, including with beam search, but there are always corner cases :)

chadj2 commented 2 months ago

@gante Here is the program to reproduce. I generate the same text with repetition_penalty=1 and repetition_penalty=2. You can see repetition_penalty does not affect the result even when set to max. This is with transformers=4.44.0.

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import torch

def test_rp():
    # disabling because this crashes on V100's
    torch.backends.cuda.enable_mem_efficient_sdp(False)

    # santacoder seems to give different outputs for repetition_penalty
    #checkpoint = "bigcode/gpt_bigcode-santacoder"

    # starcoder gives almost exactly the same output for both repetition penalties
    checkpoint = "bigcode/starcoder"

    tokenizer = AutoTokenizer.from_pretrained(checkpoint,
                                              clean_up_tokenization_spaces=False)
    model = AutoModelForCausalLM.from_pretrained(checkpoint, 
                                                device_map='auto',
                                                torch_dtype=torch.float16)

    prompt = "def print_fibonacci():"
    tok_prompt = tokenizer([prompt], return_tensors='pt').to("cuda")

    generation_config = GenerationConfig(
        early_stopping=True,
        pad_token_id = tokenizer.eos_token_id,
        max_new_tokens=100,
        num_beams=4,
        repetition_penalty=1.0)

    print(f"inputs: {tok_prompt.input_ids[0]}")
    print(generation_config)
    outputs = model.generate(**tok_prompt,
                            generation_config=generation_config,
                            tokenizer=tokenizer)
    decoded_rp1 = tokenizer.decode(outputs[0])
    print(f"outputs_rp1: {outputs[0]}")
    print(f"decoded_rp1*************:\n{decoded_rp1}\n*************")

    generation_config.repetition_penalty = 2.0
    print(generation_config)
    outputs = model.generate(**tok_prompt,
                            generation_config=generation_config,
                            tokenizer=tokenizer)
    decoded_rp2 = tokenizer.decode(outputs[0])
    print(f"outputs_rp2: {outputs[0]}")
    print(f"decoded_rp2*************:\n{decoded_rp2}\n*************")

if __name__ == "__main__":
    test_rp()

Here is the output:

Loading checkpoint shards: 100%|██████████| 7/7 [00:11<00:00,  1.70s/it]
inputs: tensor([  589,  1459,    81, 26423, 34682,  2262], device='cuda:0')
GenerationConfig {
  "early_stopping": true,
  "max_new_tokens": 100,
  "num_beams": 4,
  "pad_token_id": 0
}

outputs_rp1: tensor([  589,  1459,    81, 26423, 34682,  2262,   284,   312,    30,   323,
          280,   225,    34,    30,   225,    35,   284,  2218,  2933,    44,
          291,  1459,    26,    84,    27,   291,   312,    30,   323,   280,
          323,    30,   312,   474,   323,   478,   203,   325,  1156,   426,
          505,   610, 11477,  1831, 14935,   284,  1459,    81, 26423, 34682,
          346,   203,     0], device='cuda:0')
decoded_rp1*************:
def print_fibonacci():
    a, b = 0, 1
    while True:
        print(b)
        a, b = b, a + b

if __name__ == '__main__':
    print_fibonacci()
<|endoftext|>
*************
GenerationConfig {
  "early_stopping": true,
  "max_new_tokens": 100,
  "num_beams": 4,
  "pad_token_id": 0,
  "repetition_penalty": 2.0
}

outputs_rp2: tensor([  589,  1459,    81, 26423, 34682,  2262,   284,   312,    30,   323,
          280,   225,    34,    30,   225,    35,   284,  2218,  2933,    44,
          291,  1459,    26,    84,    27,   291,   312,    30,   323,   280,
          323,    30,   312,   474,   323,   478,   203,   325,  1156,   426,
          505,   610, 11477,  1831, 14935,   284,  1459,    81, 26423, 34682,
          346,   203,     0], device='cuda:0')
decoded_rp2*************:
def print_fibonacci():
    a, b = 0, 1
    while True:
        print(b)
        a, b = b, a + b

if __name__ == '__main__':
    print_fibonacci()
<|endoftext|>
*************
gante commented 2 months ago

@chadj2 On my end, the two cases actually have a mismatched token (many things can explain this mismatch). Moreover, if I increase the penalty to 1000 the resulting text is significantly different.

To be more precise, the comparison should be made at the token level rather than at the string level, because token decoding may hide differences. More importantly, to determine whether repetition_penalty has an impact or not, we must compare the scores before token selection. Try running the script I share below, and let me know if you get similar results 🤗

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import torch

def test_rp():
    # disabling because this crashes on V100's
    torch.backends.cuda.enable_mem_efficient_sdp(False)

    # santacoder seems to give different outputs for repetition_penalty
    #checkpoint = "bigcode/gpt_bigcode-santacoder"

    # starcoder gives almost exactly the same output for both repetition penalties
    checkpoint = "bigcode/starcoder"

    tokenizer = AutoTokenizer.from_pretrained(checkpoint,
                                              clean_up_tokenization_spaces=False)
    model = AutoModelForCausalLM.from_pretrained(checkpoint,
                                                device_map='auto',
                                                torch_dtype=torch.float16)

    prompt = "def print_fibonacci():"
    tok_prompt = tokenizer([prompt], return_tensors='pt').to("cuda")

    generation_config = GenerationConfig(
        early_stopping=True,
        pad_token_id = tokenizer.eos_token_id,
        max_new_tokens=100,
        num_beams=4,
        repetition_penalty=1.0,
        return_dict_in_generate=True,
        output_scores=True,
    )

    print("Generating with RP=1.0")
    outputs_rp1 = model.generate(**tok_prompt,
                            generation_config=generation_config,
                            tokenizer=tokenizer)
    sequence_rp1 = outputs_rp1.sequences[0]
    scores_rp1 = outputs_rp1.scores

    print("Generating with RP=2.0")
    generation_config.repetition_penalty = 2.0
    outputs_rp2 = model.generate(**tok_prompt,
                            generation_config=generation_config,
                            tokenizer=tokenizer)
    sequence_rp2 = outputs_rp2.sequences[0]
    scores_rp2 = outputs_rp2.scores

    differences = ~(sequence_rp1 == sequence_rp2)
    print(f"Are the outputs the same? {not differences.any()}")
    if differences.any():
        print("Diverging tokens:")
        print("RP=1.0")
        print(sequence_rp1[differences])
        print("RP=2.0")
        print(sequence_rp2[differences])

    same_scores = "True"
    for generated_token_idx in range(len(scores_rp1)):
        if not torch.allclose(scores_rp1[generated_token_idx], scores_rp2[generated_token_idx], atol=1e-5):
            same_scores = f"False, scores diverged on the generated token with index {generated_token_idx}"
            previously_seen_token = tok_prompt.input_ids[0, -1]
            previously_seen_token_score_rp1 = scores_rp1[generated_token_idx][0, previously_seen_token]
            previously_seen_token_score_rp2 = scores_rp2[generated_token_idx][0, previously_seen_token]
            break
    print(f"Are the scores the same? {same_scores}")
    if same_scores != "True":
        print(f"On generated token with index {generated_token_idx}:")
        print(f"RP=1.0 score for previously seen token {previously_seen_token}: {previously_seen_token_score_rp1}")
        print(f"RP=2.0 score for previously seen token {previously_seen_token}: {previously_seen_token_score_rp2}")

if __name__ == "__main__":
    test_rp()

Outputs on my end (RTX 4090, torch 2.4, transformers 4.45.dev0):

Generating with RP=1.0
Generating with RP=2.0
Are the outputs the same? False
Diverging tokens:
RP=1.0
tensor([83], device='cuda:0')
RP=2.0
tensor([84], device='cuda:0')
Are the scores the same? False, scores diverged on the generated token with index 0
On generated token with index 0:
RP=1.0 score for previously seen token 2262: -17.363876342773438
RP=2.0 score for previously seen token 2262: -34.727752685546875

☝️ note that the log score for the previously seen token with repetition_penalty=2.0 is exactly the score of repetition_penalty=1.0 times 2.0
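
The factor of two follows from how the penalty is applied to the scores of tokens that have already appeared: negative scores are multiplied by the penalty and positive scores are divided by it, so a penalty above 1 always makes a seen token less likely. The plain-Python sketch below illustrates that rule on a toy score list. It is a simplified illustration, not the transformers code, and `apply_repetition_penalty` is a hypothetical name.

```python
def apply_repetition_penalty(scores, seen_token_ids, penalty):
    """Return a copy of `scores` with previously seen tokens penalized.

    scores: list of floats, one score per vocabulary entry.
    seen_token_ids: token ids that already appear in the sequence.
    penalty: strictly positive float; 1.0 leaves the scores unchanged.
    """
    if penalty <= 0:
        raise ValueError("penalty must be strictly positive")
    penalized = list(scores)
    for token_id in set(seen_token_ids):
        score = penalized[token_id]
        # Multiply negative scores and divide positive ones, so that the
        # token becomes less likely in both cases when penalty > 1.
        penalized[token_id] = score * penalty if score < 0 else score / penalty
    return penalized


# Toy vocabulary of three tokens; token 1 has already been seen and carries
# the RP=1.0 score reported above.
toy_scores = [2.0, -17.363876342773438, 0.5]
print(apply_repetition_penalty(toy_scores, seen_token_ids=[1], penalty=2.0))
```

With a penalty of 2.0, the negative score of the seen token doubles in magnitude while the other entries stay untouched, which matches the RP=1.0 and RP=2.0 values printed above.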