OoriData / Toolio

AI API implementation for Mac which supports tool-calling & other structured LLM response generation (e.g. conforming to a JSON schema)

Weirdness with tokenization in Phi-3 #12

Open uogbuji opened 1 month ago

uogbuji commented 1 month ago

Server:

toolio_server --model=mlx-community/Phi-3-mini-128k-instruct-4bit

Client:

toolio_request --apibase="http://localhost:8000" --prompt='What is the average airspeed of an unladen swallow?'

You can run the above any number of times, but as soon as you run a variant that tries to reuse the prior prompt cache:

toolio_request --apibase="http://localhost:8000" --prompt='What is the average airspeed of an unladen swallow? Where have I heard that before?'

It blows up. Server exception tail:

  File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/toolio/cli/server.py", line 271, in post_v1_chat_completions_impl
    for result in app.state.model.completion(
  File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/toolio/schema_helper.py", line 296, in completion
    logits, cache = self._evaluate_prompt(
                    ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/toolio/schema_helper.py", line 92, in _evaluate_prompt
    logits = self.model(mx.array(tokens)[None], cache)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/mlx_lm/models/phi3.py", line 202, in __call__
    out = self.model(inputs, cache)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/mlx_lm/models/phi3.py", line 184, in __call__
    h = layer(h, mask, c)
        ^^^^^^^^^^^^^^^^^
  File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/mlx_lm/models/phi3.py", line 148, in __call__
    r = self.self_attn(self.input_layernorm(x), mask, cache)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/mlx_lm/models/phi3.py", line 110, in __call__
    output = mx.fast.scaled_dot_product_attention(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Shapes (1,32,9,24) and (9,9) cannot be broadcast.
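
Aside on the numbers: the scores shape (1, 32, 9, 24) looks like (batch, heads, new_tokens, reused_cache + new_tokens), while the (9, 9) mask covers only the new tokens. A quick sketch of the two mask shapes in plain mlx (illustrative only, not mlx_lm's actual mask code):

import mlx.core as mx

# Illustrative numbers only: 15 reused cache positions, 9 newly evaluated tokens
offset, n_new = 15, 9

# Mask built from the new tokens alone, ignoring the cache: shape (9, 9)
q = mx.arange(n_new).reshape(n_new, 1)
naive_mask = mx.arange(n_new).reshape(1, n_new) <= q

# Mask that accounts for the cache offset: shape (9, 24), matching the keys
rows = mx.arange(offset, offset + n_new).reshape(n_new, 1)   # absolute query positions
cols = mx.arange(offset + n_new).reshape(1, offset + n_new)  # absolute key positions
offset_mask = cols <= rows

print(naive_mask.shape, offset_mask.shape)  # (9, 9) vs (9, 24)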

I modified schema_helper.py to get a trace:

    def _evaluate_prompt(
        self, prompt: list[int], prior_prompt: list[int] = None, prior_cache=None
    ):
        if prior_prompt:
            i = 0
            for i, t in enumerate(prior_prompt):
                # Need to leave at least one token to evaluate because we don't
                # save the past logits.
                if i >= len(prompt) - 1 or prompt[i] != t:
                    break
            cache = prior_cache
            for layer_cache in cache:
                layer_cache.reuse(len(prompt), i)
            tokens = prompt[i:]
            print('CACHED', tokens, prompt)
        else:
            cache = ReusableKVCache.for_model(self.model)
            tokens = prompt
            print('UNCACHED', tokens)

        logits = self.model(mx.array(tokens)[None], cache)
        return logits, cache

First run of the shorter prompt displays:

UNCACHED [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 32007, 32007]

Notice right away the repeated 32007, which is Phi-3's '<|end|>' token. That's probably not good. Running the identical prompt again:

CACHED [32007] [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 32007, 32007]

That's the expected logic: nothing but that final end token is left to evaluate post-cache. Now the longer prompt:

CACHED [6804, 505, 306, 6091, 393, 1434, 29973, 32007, 32007] [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 6804, 505, 306, 6091, 393, 1434, 29973, 32007, 32007]

The end-of-prompt token is doubled again.
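
As a sanity check on the cache splitting itself, the prefix-matching loop from _evaluate_prompt above can be run standalone against these exact token lists (a quick sketch, independent of Toolio):

short = [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 32007, 32007]
longer = short[:15] + [6804, 505, 306, 6091, 393, 1434, 29973, 32007, 32007]

def split_on_prefix(prompt, prior_prompt):
    # Same loop as _evaluate_prompt: find the first index where the prompts
    # diverge, keeping at least one token to evaluate
    i = 0
    for i, t in enumerate(prior_prompt):
        if i >= len(prompt) - 1 or prompt[i] != t:
            break
    return i, prompt[i:]

print(split_on_prefix(short, short))    # (16, [32007])
print(split_on_prefix(longer, short))   # (15, [6804, 505, 306, 6091, 393, 1434, 29973, 32007, 32007])

So the splitter is doing what it should; the doubled 32007 is already present in the uncached tokenization above, before any cache reuse happens.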

At this point I don't know whether this tokenizer oddness is what leads to the shape error, but it's a starting point for investigation.

uogbuji commented 1 month ago

Quick look at the Phi-3 tokenizer:

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained(
    # Should be the same tokenizer as microsoft/Phi-3-mini-128k-instruct
    'mlx-community/Phi-3-mini-128k-instruct-4bit'
)
S = 'Hello<|end|>'
ids = tokenizer.encode(S, add_special_tokens=False)
print(ids)
S_decode = tokenizer.decode(ids)
print(repr(S_decode))

S = 'Hello<|end|>'
ids = tokenizer.encode(S, add_special_tokens=True)
print(ids)
S_decode = tokenizer.decode(ids)
print(repr(S_decode))

Output:

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
[15043, 32007]
'Hello<|end|>'
[15043, 32007]
'Hello<|end|>'

The 'Special tokens' warning comes up as soon as you load the tokenizer, and has nothing to do with the add_special_tokens=True|False setting later on.

repr of tokenizer:

LlamaTokenizerFast(name_or_path='mlx-community/Phi-3-mini-128k-instruct-4bit', vocab_size=32000, model_max_length=131072, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '<|end|>', 'unk_token': '<unk>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
        0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        2: AddedToken("</s>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=False),
        32000: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        32001: AddedToken("<|assistant|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
        32002: AddedToken("<|placeholder1|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
        32003: AddedToken("<|placeholder2|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
        32004: AddedToken("<|placeholder3|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
        32005: AddedToken("<|placeholder4|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
        32006: AddedToken("<|system|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
        32007: AddedToken("<|end|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
        32008: AddedToken("<|placeholder5|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
        32009: AddedToken("<|placeholder6|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
        32010: AddedToken("<|user|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
}

So yes, Phi-3 uses the Llama tokenizer. Notice that the chat special tokens, including <|end|> (32007), are added with rstrip=True, i.e. whitespace to their right is normalized away.
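
If the rstrip behaviour turns out to matter, it can be probed directly; a hypothetical follow-up along these lines (reusing the tokenizer loaded above; not run here) would show whether whitespace to the right of <|end|> survives an encode/decode round-trip:

probe = 'Hello<|end|>   world'
ids = tokenizer.encode(probe, add_special_tokens=False)
print(ids)
# With rstrip=True on the <|end|> AddedToken, whitespace immediately to its
# right may be consumed during tokenization and not survive the round-trip
print(repr(tokenizer.decode(ids)))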

uogbuji commented 1 month ago

A trimmed-down repro case:

import mlx.core as mx
from toolio.schema_helper import Model, ReusableKVCache

m = Model()
m.load('mlx-community/Phi-3-mini-128k-instruct-4bit')
cache = ReusableKVCache.for_model(m.model)
tokens1 = [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 32007, 32007]
logits = m.model(mx.array(tokens1)[None], cache)

cached_prompt = logits

tokens2 = [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 6804, 505, 306, 6091, 393, 1434, 29973, 32007, 32007]
tokens2_postcache = [6804, 505, 306, 6091, 393, 1434, 29973, 32007, 32007]
for layer_cache in cache:
    layer_cache.reuse(len(tokens2), len(tokens2)-1)

logits = m.model(mx.array(tokens2_postcache)[None], cache)

Result: ValueError: Shapes (1,32,9,32) and (9,9) cannot be broadcast.

Note: just blindly replacing each doubled 32007, 32007 with a single 32007 (as in the variant below) merely tweaked the error: ValueError: Shapes (1,32,8,30) and (8,8) cannot be broadcast.

cache = ReusableKVCache.for_model(m.model)
tokens1 = [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 32007]
logits = m.model(mx.array(tokens1)[None], cache)

cached_prompt = logits

tokens2 = [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 6804, 505, 306, 6091, 393, 1434, 29973, 32007]
tokens2_postcache = [6804, 505, 306, 6091, 393, 1434, 29973, 32007]
for layer_cache in cache:
    layer_cache.reuse(len(tokens2), len(tokens2)-1)

logits = m.model(mx.array(tokens2_postcache)[None], cache)
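
For what it's worth, the shapes in all three errors follow the same arithmetic, assuming reuse(total, kept) leaves kept positions in the cache: the keys span kept + n_new positions while the mask covers only the n_new post-cache tokens.

# Assumption: reuse(total, kept) leaves `kept` cache positions in play
cases = [
    ('repro above', 23, 9),           # keys 32, mask 9x9:  (1,32,9,32) vs (9,9)
    ('single-32007 variant', 22, 8),  # keys 30, mask 8x8:  (1,32,8,30) vs (8,8)
    ('original server error', 15, 9), # keys 24, mask 9x9:  (1,32,9,24) vs (9,9)
]
for name, kept, n_new in cases:
    print(f'{name}: keys={kept + n_new}, mask={n_new}x{n_new}')
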
uogbuji commented 1 month ago

For now I've worked around this by disabling prompt caching by default. I'll leave the ticket open, though, because it would be nice to work out a proper fix.