Quick look at the Phi-3 tokenizer:
import transformers
tokenizer = transformers.AutoTokenizer.from_pretrained(
    # Should be the same tokenizer as microsoft/Phi-3-mini-128k-instruct
    'mlx-community/Phi-3-mini-128k-instruct-4bit'
)
S = 'Hello<|end|>'
ids = tokenizer.encode(S, add_special_tokens=False)
print(ids)
S_decode = tokenizer.decode(ids)
print(repr(S_decode))
S = 'Hello<|end|>'
ids = tokenizer.encode(S, add_special_tokens=True)
print(ids)
S_decode = tokenizer.decode(ids)
print(repr(S_decode))
Output:
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
[15043, 32007]
'Hello<|end|>'
[15043, 32007]
'Hello<|end|>'
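A quick follow-up check (not in the original snippet) confirms what that id is; convert_ids_to_tokens and eos_token_id are standard transformers tokenizer accessors:
print(tokenizer.convert_ids_to_tokens(32007))  # '<|end|>'
print(tokenizer.eos_token_id)                  # 32007, matching eos_token='<|end|>' in the repr below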
The 'Special tokens' warning comes up as soon as you load the tokenizer, and has nothing to do with add_special_tokens=True|False later on.
repr of the tokenizer:
LlamaTokenizerFast(name_or_path='mlx-community/Phi-3-mini-128k-instruct-4bit', vocab_size=32000, model_max_length=131072, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '<|end|>', 'unk_token': '<unk>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=False), added_tokens_decoder={
0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
2: AddedToken("</s>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=False),
32000: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
32001: AddedToken("<|assistant|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
32002: AddedToken("<|placeholder1|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
32003: AddedToken("<|placeholder2|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
32004: AddedToken("<|placeholder3|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
32005: AddedToken("<|placeholder4|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
32006: AddedToken("<|system|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
32007: AddedToken("<|end|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
32008: AddedToken("<|placeholder5|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
32009: AddedToken("<|placeholder6|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
32010: AddedToken("<|user|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
}
So yes, Phi-3 uses the Llama tokenizer. Notice that the special tokens are added with rstrip=True, i.e. with whitespace normalization.
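To illustrate (a minimal check, not from the original issue; the exact ids assume the same encoding as above), rstrip=True means whitespace immediately to the right of the special token is silently dropped, so such strings don't round-trip exactly:
S = 'Hello<|end|> \n'
ids = tokenizer.encode(S, add_special_tokens=False)
print(ids)                          # expect [15043, 32007], trailing whitespace gone
print(repr(tokenizer.decode(ids)))  # 'Hello<|end|>', not 'Hello<|end|> \n'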
A trimmed down repro case:
import mlx.core as mx
from toolio.schema_helper import Model, ReusableKVCache
m = Model()
m.load('mlx-community/Phi-3-mini-128k-instruct-4bit')
from mlx_lm.models.base import KVCache
cache = ReusableKVCache.for_model(m.model)
tokens1 = [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 32007, 32007]
logits = m.model(mx.array(tokens1)[None], cache)
cached_prompt = logits
tokens2 = [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 6804, 505, 306, 6091, 393, 1434, 29973, 32007, 32007]
tokens2_postcache = [6804, 505, 306, 6091, 393, 1434, 29973, 32007, 32007]
for layer_cache in cache:
    layer_cache.reuse(len(tokens2), len(tokens2)-1)
logits = m.model(mx.array(tokens2_postcache)[None], cache)
Result: ValueError: Shapes (1,32,9,32) and (9,9) cannot be broadcast.
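For reference, the hard-coded tokens2_postcache above is just the suffix of tokens2 past its shared prefix with tokens1; a small helper (hypothetical, not part of toolio) makes that explicit:
def common_prefix_len(a, b):
    # Count how many leading token ids the two prompts share
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n

prefix = common_prefix_len(tokens1, tokens2)  # 15 for the sequences above
assert tokens2[prefix:] == tokens2_postcache  # the suffix actually fed to the model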
Note: just blindly replacing each 32007, 32007 pair with a single 32007 merely tweaked the error: ValueError: Shapes (1,32,8,30) and (8,8) cannot be broadcast.
cache = ReusableKVCache.for_model(m.model)
tokens1 = [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 32007]
logits = m.model(mx.array(tokens1)[None], cache)
cached_prompt = logits
tokens2 = [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 6804, 505, 306, 6091, 393, 1434, 29973, 32007]
tokens2_postcache = [6804, 505, 306, 6091, 393, 1434, 29973, 32007]
for layer_cache in cache:
    layer_cache.reuse(len(tokens2), len(tokens2)-1)
logits = m.model(mx.array(tokens2_postcache)[None], cache)
For now I've worked around this by disabling prompt caching by default. I'll leave the ticket open, though, because it would be nice to work out a proper fix.
Server:
Client:
You can run the above any number of times, but as soon as you run a version that tries to use a prior prompt cache:
It blows up. Server exception tail:
Modified schema_helper.py for a trace. First run of the shorter prompt displays:
Already notice the repeated 32007, which is the Phi-3 '<|end|>' token. This is probably not good. Identical run again:
Expected logic, with nothing but that end token post-cache. Now the longer prompt:
End prompt is re-doubled.
At this point I don't know whether this tokenizer oddness is what leads to the shape error, but it's a start for investigating.
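If the doubled end token does turn out to be the culprit, one direction to investigate (a hypothetical sketch, not a confirmed fix) would be collapsing a repeated trailing end token before the ids ever reach the cache:
def collapse_trailing(ids, tok_id=32007):
    # Drop extra copies of a repeated trailing token (Phi-3's <|end|> is id 32007)
    out = list(ids)
    while len(out) >= 2 and out[-1] == tok_id and out[-2] == tok_id:
        out.pop()
    return out

print(collapse_trailing([15043, 32007, 32007]))  # [15043, 32007]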