
Beam search and greedy search set inconsistent eos_token #34035


user799595 commented 3 weeks ago

System Info

transformers: 4.44.1
Platform: Linux-5.15.0
Python: 3.10.6
PyTorch: 2.4.0+cu121

Who can help?

@gante

Reproduction

import transformers
import torch

num_beams = 2  # 2 = beam search; set to 1 for greedy search
model_name = "unsloth/llama-3-8b-Instruct-bnb-4bit"
model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="cuda")
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you? Answer briefly."},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
tokenized_prompt = tokenizer(prompt, return_tensors="pt").to("cuda")
generation = model.generate(**tokenized_prompt, num_beams=num_beams, max_new_tokens=100)
# keep special tokens so the terminating eos token is visible in the output
results = tokenizer.batch_decode(generation, skip_special_tokens=False)[0]
print(results)

Expected behavior

I would expect the eos_token used to be consistent across decoding strategies, but beam search and greedy search terminate with different tokens.

With num_beams = 2 (beam search)

Arrr, me hearty! I be Captain Chat, the scurviest pirate chatbot to ever sail the Seven Seas o' the Interwebs!<|end_of_text|>

With num_beams = 1 (greedy search)

Arrr, me hearty! I be Captain Chat, the scurviest chatbot to ever sail the seven seas... er, chat the digital seas!<|eot_id|>
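
For reference, a minimal sketch that makes the mismatch explicit by printing the final token id produced under each setting. It reuses model, tokenizer and tokenized_prompt from the reproduction above and nothing beyond the stock API:

# Sketch: compare the last generated token under greedy vs. beam search.
# Assumes generation ends on an eos token before max_new_tokens is reached,
# as in the outputs shown above.
for beams in (1, 2):
    out = model.generate(**tokenized_prompt, num_beams=beams, max_new_tokens=100)
    last_id = out[0, -1].item()
    print(f"num_beams={beams}: last token id {last_id} -> {tokenizer.convert_ids_to_tokens(last_id)}")
# Observed above: num_beams=1 ends with <|eot_id|>, num_beams=2 with <|end_of_text|>.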

not-lain commented 3 weeks ago

I think I found the culprit. According to https://huggingface.co/unsloth/llama-3-8b-Instruct-bnb-4bit/blob/main/generation_config.json#L4, there are two eos tokens, so the model can stop on either token 128009 or 128001. The odd part is that according to https://huggingface.co/unsloth/llama-3-8b-Instruct-bnb-4bit/blob/main/special_tokens_map.json#L10, the eos token should be <|eot_id|>, not both of them, so I'm not sure what to make of this. I'll keep looking into it in my spare time, but I wanted to report this first.
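
For anyone who wants to check the config mismatch without loading the model, here is a small sketch using the standard AutoTokenizer / GenerationConfig API (the values in the comments come from the two files linked above), plus one possible workaround of pinning the terminator explicitly:

from transformers import AutoTokenizer, GenerationConfig

model_name = "unsloth/llama-3-8b-Instruct-bnb-4bit"
tok = AutoTokenizer.from_pretrained(model_name)
gen_cfg = GenerationConfig.from_pretrained(model_name)

# special_tokens_map.json says the eos token is <|eot_id|> only
print("tokenizer eos:", tok.eos_token, tok.eos_token_id)
# generation_config.json lists two eos ids (128001 and 128009),
# which is what allows beam search to stop on <|end_of_text|>
print("generation_config eos:", gen_cfg.eos_token_id)

# Possible workaround (untested): pass a single eos id to generate() so both
# decoding strategies stop on the same token, e.g.
#   model.generate(**tokenized_prompt, num_beams=2, max_new_tokens=100,
#                  eos_token_id=tok.eos_token_id)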