Hey @chrisliu298 -- we abstracted away from the HuggingFace pipeline to our own one to simplify some of this.
Does the HF Tokenizer add a bos_token by default? See here: https://github.com/allenai/reward-bench/blob/main/rewardbench/models/pipeline.py
Ah okay, I just ran an example and this seems right. Hmm, looking into it.
Minimal example:
from transformers import AutoTokenizer

# Check whether this tokenizer adds a BOS token by default.
tokenizer = AutoTokenizer.from_pretrained("oobabooga/llama-tokenizer")
out = tokenizer("Testing my text")
print(out)
print(tokenizer.bos_token)
print(tokenizer.convert_ids_to_tokens(out['input_ids']))
@chrisliu298 -- it varies by model. Some examples follow (that should likely be fixed).
I think the solution is to check, in the standard inference pipeline, whether the bos_token gets doubled (a programmatic check is sketched after the example below).
Example with the failure:
>>> chat = [
... {"role": "user", "content": "Hello, how are you?"},
... {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
... {"role": "user", "content": "I'd like to show off how chat templating works!"},
... ]
>>> chat = tokenizer.apply_chat_template(chat, tokenize=False)
>>> chat
"<s><|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n<|reward|>"
>>> tokenizer(chat)
{'input_ids': [1, 1, 92543, 1008, 364, 9843, 328, 1392, 657, 629, 345, 92542, 364, 92543, 525, 11353, 364, 295, 2940, 3890, 2395, 281, 2745, 777, 489, 1638, 629, 3514, 345, 92542, 364, 92543, 1008, 364, 295, 4330, 1217, 442, 1620, 1147, 1392, 6392, 1708, 631, 1237, 4437, 346, 92542, 364, 92527], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
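A quick way to check this programmatically (a minimal sketch; the RM-Mistral-7B tokenizer is the example linked below in this thread, and the variable names are mine):
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("weqweasdas/RM-Mistral-7B")
chat = [{"role": "user", "content": "Hello, how are you?"}]
text = tokenizer.apply_chat_template(chat, tokenize=False)
ids = tokenizer(text)['input_ids']

# If the chat template already emits <s> and the tokenizer also prepends a BOS,
# the first two ids will both be bos_token_id.
doubled = (
    tokenizer.bos_token_id is not None
    and len(ids) >= 2
    and ids[0] == ids[1] == tokenizer.bos_token_id
)
print("BOS doubled:", doubled)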
I'll submit a small fix for off-by-one errors in the default pipeline.
The top model with this issue that I've found is https://huggingface.co/weqweasdas/RM-Mistral-7B/blob/main/tokenizer_config.json; most models use custom code or do not have a BOS token. The specific implementation of InternLM could be wrong, but I consider that a separate issue.
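For reference, one possible shape of such a fix (a minimal sketch, not the actual patch; the helper name is hypothetical): after formatting with the chat template, drop a duplicated leading BOS id before the forward pass.
def drop_duplicate_bos(tokenizer, text, **tokenizer_kwargs):
    # Hypothetical helper: tokenize chat-templated text and strip a doubled BOS.
    enc = tokenizer(text, **tokenizer_kwargs)
    ids = enc['input_ids']
    bos = tokenizer.bos_token_id
    if bos is not None and len(ids) >= 2 and ids[0] == ids[1] == bos:
        enc['input_ids'] = ids[1:]
        enc['attention_mask'] = enc['attention_mask'][1:]
    return enc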
I noticed that, for default (sequence classification) models with a chat template defined in the tokenizer, scripts/run_rm.py formats each conversation with tokenizer.apply_chat_template (via the function prepare_dialogue_from_tokenizer) and then uses the text classification pipeline to process the formatted conversations. Given that 1) many models' tokenizers (e.g., the Llama-3 instruct series, the Gemma-2 instruct series, etc.) include the bos_token in the chat template, and 2) the pipeline adds another bos_token during tokenization, does this mean these models see two BOS tokens in the forward pass?

I also realized that some models (e.g., ArmoRM) inherently avoid this potential issue via a customized pipeline that performs tokenization directly with tokenizer.apply_chat_template (as opposed to first formatting, then tokenizing); a sketch of that approach is below.
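A minimal sketch of that "tokenize directly" approach (the model name is a placeholder, and this is not ArmoRM's actual pipeline code): apply_chat_template with tokenize=True inserts the special tokens exactly once, so no second bos_token is added by a later tokenizer call.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "path/to/sequence-classification-reward-model"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

chat = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
]
# Tokenize directly from the chat template; special tokens are added exactly once.
input_ids = tokenizer.apply_chat_template(chat, tokenize=True, return_tensors="pt")
with torch.no_grad():
    reward = model(input_ids).logits
print(reward)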