huggingface / tokenizers

💥 Fast State-of-the-Art Tokenizers optimized for Research and Production
https://huggingface.co/docs/tokenizers
Apache License 2.0

Llama-3 offset-mapping needs fixing #1553

Open davidb-cerebras opened 3 months ago

davidb-cerebras commented 3 months ago

Opening a new issue to follow up on the previously opened issue here: https://github.com/huggingface/tokenizers/issues/1517

Here we can see that the desired behavior of return_offsets_mapping, as shown by Mistral, is to give character indices corresponding to each token:

(Pdb) from transformers import AutoTokenizer
(Pdb) tok_mistral = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
(Pdb) tok_mistral(["Sample input"], return_offsets_mapping=True)
{'input_ids': [[1, 27797, 2787]], 'attention_mask': [[1, 1, 1]], 'offset_mapping': [[(0, 0), (0, 6), (6, 12)]]}
(Pdb) tok_mistral.convert_ids_to_tokens([1, 27797, 2787])
['<s>', '▁Sample', '▁input']
(Pdb) "Sample input"[0:6]
'Sample'
(Pdb) "Sample input"[6:12]
' input'

But for Llama-3 the offsets are not correct:

(Pdb) tok_llama3 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") 
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
(Pdb) tok_llama3(["Sample input"], return_offsets_mapping=True)
{'input_ids': [[128000, 18031, 1988]], 'attention_mask': [[1, 1, 1]], 'offset_mapping': [[(0, 0), (0, 0), (6, 6)]]}
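
Slicing the input with these offsets recovers empty strings instead of the token text:

(Pdb) "Sample input"[0:0]
''
(Pdb) "Sample input"[6:6]
''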

We can also see Llama-2 and GPT-2 behaving the same way as Mistral, so Llama-3 is definitely the one exhibiting the unexpected behavior:

(Pdb) tok_llama2 = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
(Pdb) tok_llama2(["Sample input"], return_offsets_mapping=True)
{'input_ids': [[1, 21029, 1881]], 'attention_mask': [[1, 1, 1]], 'offset_mapping': [[(0, 0), (0, 6), (6, 12)]]}
(Pdb) tok_gpt2 = AutoTokenizer.from_pretrained("openai-community/gpt2") 
(Pdb) tok_gpt2(["Sample input"], return_offsets_mapping=True)
{'input_ids': [[36674, 5128]], 'attention_mask': [[1, 1]], 'offset_mapping': [[(0, 6), (6, 12)]]}
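
For reference, here is a minimal consolidated reproduction sketch (assuming transformers with fast tokenizers installed and access to the gated meta-llama checkpoint) that prints, for each tokenizer, the tokens, the reported offsets, and the text slice each offset pair recovers:

from transformers import AutoTokenizer

TEXT = "Sample input"
CHECKPOINTS = [
    "mistralai/Mistral-7B-Instruct-v0.1",
    "NousResearch/Llama-2-7b-hf",
    "openai-community/gpt2",
    "meta-llama/Meta-Llama-3-8B-Instruct",
]

for name in CHECKPOINTS:
    tok = AutoTokenizer.from_pretrained(name)
    enc = tok(TEXT, return_offsets_mapping=True)
    tokens = tok.convert_ids_to_tokens(enc["input_ids"])
    # Each (start, end) pair should slice TEXT back to the text covered by that token.
    slices = [TEXT[start:end] for start, end in enc["offset_mapping"]]
    print(name)
    print("  tokens: ", tokens)
    print("  offsets:", enc["offset_mapping"])
    print("  slices: ", slices)

With the outputs reported above, only the Llama-3 entry should yield empty slices for the non-special tokens.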
davidb-cerebras commented 3 months ago

@ArthurZucker Is it possible to fix this in tokenizers?

ArthurZucker commented 3 months ago

Yep, you are right. I'll dig in a bit to see why we have this!

davidb-cerebras commented 3 months ago

Awesome thank you!

maximilianmordig commented 3 months ago

@ArthurZucker Is there a workaround in the meantime?

ArthurZucker commented 2 months ago

Sorry, not yet! I am fixing a bunch of stuff; maybe #1568?

davidb-cerebras commented 2 months ago

@maximilianmordig Cerebras has implemented a wrapper that corrects the buggy method; feel free to use the wrapper class here: https://github.com/Cerebras/modelzoo/blob/main/src/cerebras/modelzoo/data_preparation/data_preprocessing/custom_tokenizer_example/CustomLlama3Tokenizer.py
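
For illustration only (this is not the Cerebras wrapper, just a sketch that assumes the reported start positions are correct and that no padding or truncation is applied), one way to repair the offsets is to close each span at the start of the next token:

from transformers import AutoTokenizer

def fix_offsets(offset_mapping, text_len):
    # Hypothetical helper: end each span where the next one starts; the last
    # span ends at the end of the text. Trailing special tokens would need
    # extra handling.
    fixed = []
    for i, (start, _) in enumerate(offset_mapping):
        end = offset_mapping[i + 1][0] if i + 1 < len(offset_mapping) else text_len
        fixed.append((start, max(start, end)))
    return fixed

text = "Sample input"
tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
enc = tok(text, return_offsets_mapping=True)
print(fix_offsets(enc["offset_mapping"], len(text)))
# With the offsets reported above, this should print [(0, 0), (0, 6), (6, 12)].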

github-actions[bot] commented 1 month ago

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

srinjoym-cerebras commented 1 month ago

Hey, any update on this?

ArthurZucker commented 1 month ago

Hey! Sorry, not yet. It's not on my stack, but I will investigate for the next release since there is a need from all of you! 🤗

tcleberg commented 1 week ago

Is there anyone whose stack this is on who can try to resolve this?

ArthurZucker commented 2 days ago

I think it's ignore_merges
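
A quick way to probe that hypothesis, assuming the installed tokenizers release (>= 0.19) exposes ignore_merges on the Python BPE model:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
bpe = tok.backend_tokenizer.model  # the underlying tokenizers BPE model
print("ignore_merges:", getattr(bpe, "ignore_merges", "not exposed"))

if hasattr(bpe, "ignore_merges"):
    # Toggling the flag can also change the produced token ids; this is only a
    # check of whether the offsets react, not a recommended fix.
    bpe.ignore_merges = False
    print(tok("Sample input", return_offsets_mapping=True)["offset_mapping"])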