Open kiva12138 opened 5 months ago
I'm having the same issue. Neither of these change the encodings:
tokenizer.add_bos_token = False
tokenizer.add_eos_token = True
cc @ArthurZucker
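For reference, the pattern being tried looks roughly like this (the checkpoint path and sample text are placeholders):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/llama3")  # placeholder Llama-3 checkpoint

tokenizer.add_bos_token = False  # no effect: the fast tokenizer's post-processor never reads this flag
tokenizer.add_eos_token = True   # no effect either

print(tokenizer("a b")["input_ids"])  # bos is still prepended, eos is still missing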
Hey! This is related to #30607: the tokenizer for Llama3 is a PreTrainedTokenizerFast, not the LlamaTokenizer or a LlamaTokenizerFast. Though it might actually be good to support an easy way to add bos and eos. Currently what you have to do is update the TemplateProcessor, which is fairly annoying (not beginner friendly). That's something that should be handled on the tokenizers side.
@ArthurZucker I think it's called TemplateProcessing, not TemplateProcessor. For those wondering, this is how I used it to get the tokenizer to append the eos token:
from tokenizers import processors

bos = "<|begin_of_text|>"
eos = "<|end_of_text|>"

# Rebuild the fast tokenizer's post-processor so that bos and eos are added.
# `tokenizer` is the PreTrainedTokenizerFast loaded via AutoTokenizer.from_pretrained.
tokenizer._tokenizer.post_processor = processors.Sequence(
    [
        processors.ByteLevel(trim_offsets=False),
        processors.TemplateProcessing(
            single=f"{bos}:0 $A:0 {eos}:0",
            pair=f"{bos}:0 $A:0 {bos}:1 $B:1 {eos}:1",
            special_tokens=[
                (bos, tokenizer.bos_token_id),
                (eos, tokenizer.eos_token_id),
            ],
        ),
    ]
)
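A quick sanity check afterwards (the exact ids depend on the checkpoint):

ids = tokenizer("hello")["input_ids"]
assert ids[0] == tokenizer.bos_token_id and ids[-1] == tokenizer.eos_token_id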
Now I'm worried that the padding tokens won't get added properly, but that's a different issue...
Padding token is unrelated, it's added if you ask the tokenizer to pad the input! And yes, thanks for providing the snippet @eyloncaplan 😉
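For example, padding only happens when it is requested and a pad token is set (Llama checkpoints ship without one, so reusing eos is a common workaround):

tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers define no pad token by default
batch = tokenizer(["a", "a b c"], padding=True)
print(batch["input_ids"])  # the shorter sequence is padded to the longer one's length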
In case anyone else is blocked by this issue, I copied code from #31316 into a function which patches the tokenizer to support dynamically setting add_bos_token and add_eos_token.
Running it prints:
Load plain tokenizer
Default: [128000, 64, 293]
Add EOS: [128000, 64, 293]
Load and patch tokenizer
Add EOS: [128000, 64, 293, 128009]
Don't add: [128000, 64, 293]
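A minimal sketch of that kind of patch, reusing the TemplateProcessing approach from the snippet above (the helper name set_bos_eos and its signature are illustrative, not the actual code from #31316):

from tokenizers import processors
from transformers import AutoTokenizer

def set_bos_eos(tokenizer, add_bos: bool, add_eos: bool):
    # Rebuild the fast tokenizer's post-processor so bos/eos are added (or not) on demand.
    bos, eos = tokenizer.bos_token, tokenizer.eos_token
    single = " ".join(
        ([f"{bos}:0"] if add_bos else []) + ["$A:0"] + ([f"{eos}:0"] if add_eos else [])
    )
    pair = " ".join(
        ([f"{bos}:0"] if add_bos else [])
        + ["$A:0"]
        + ([f"{bos}:1"] if add_bos else [])
        + ["$B:1"]
        + ([f"{eos}:1"] if add_eos else [])
    )
    specials = []
    if add_bos:
        specials.append((bos, tokenizer.bos_token_id))
    if add_eos:
        specials.append((eos, tokenizer.eos_token_id))
    tokenizer._tokenizer.post_processor = processors.Sequence(
        [
            processors.ByteLevel(trim_offsets=False),
            processors.TemplateProcessing(single=single, pair=pair, special_tokens=specials),
        ]
    )

tok = AutoTokenizer.from_pretrained("path/to/llama3")  # placeholder checkpoint path
set_bos_eos(tok, add_bos=True, add_eos=True)   # encodings now end with the eos id
set_bos_eos(tok, add_bos=True, add_eos=False)  # back to the default: bos only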
System Info
Platform = Windows, PyTorch = 2.3.0, Transformers = 4.41.0
Who can help?
No response
Information
Tasks
Reproduction
All of the statements above produce
[128000, 6151, 11, 1268, 527, 499, 3432, 30]
Expected behavior
I think when using tokenizer = AutoTokenizer.from_pretrained(LLaMAPath, add_bos_token=True, add_eos_token=True), we should get [128000, 6151, 11, 1268, 527, 499, 3432, 30, 128001], and when using tokenizer = AutoTokenizer.from_pretrained(LLaMAPath, add_bos_token=False, add_eos_token=False), we should get [6151, 11, 1268, 527, 499, 3432, 30].
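For what it's worth, both expected behaviors can be approximated today by rebuilding the post-processor, e.g. with the illustrative set_bos_eos helper sketched earlier:

tok = AutoTokenizer.from_pretrained(LLaMAPath)      # LLaMAPath as used above
set_bos_eos(tok, add_bos=True, add_eos=True)        # ids gain a trailing eos id
set_bos_eos(tok, add_bos=False, add_eos=False)      # ids lose the leading bos id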