OoriData / Toolio

AI API implementation for Mac that supports tool-calling & other structured LLM response generation (e.g. conforming to a JSON schema)

Gemma support (re Jinja error: "System role not supported") #1

Closed: uogbuji closed this issue 2 months ago

uogbuji commented 3 months ago

Originally reported by Mark Lord.

Repro steps:

python -m mlx_lm.convert --hf-path UCLA-AGI/Gemma-2-9B-It-SPPO-Iter3 --mlx-path ~/.local/share/models/mlx/Gemma-2-9B-It-SPPO-Iter3-8bit -q --q-bits 8
MLXStructuredLMServer --model=$HOME/.local/share/models/mlx/Gemma-2-9B-It-SPPO-Iter3-8bit

Try a request such as:

echo 'What is the square root of 256?' > /tmp/llmprompt.txt
echo '{"tools": [{"type": "function","function": {"name": "square_root","description": "Get the square root of the given number","parameters": {"type": "object", "properties": {"square": {"type": "number", "description": "Number from which to find the square root"}},"required": ["square"]},"pyfunc": "math|sqrt"}}], "tool_choice": "auto"}' > /tmp/toolspec.json
toolio_request --apibase="http://127.0.0.1:8000" --prompt-file=/tmp/llmprompt.txt --tools-file=/tmp/toolspec.json
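
For readability, the same tool spec pretty-printed:

{
  "tools": [
    {
      "type": "function",
      "function": {
        "name": "square_root",
        "description": "Get the square root of the given number",
        "parameters": {
          "type": "object",
          "properties": {
            "square": {
              "type": "number",
              "description": "Number from which to find the square root"
            }
          },
          "required": ["square"]
        },
        "pyfunc": "math|sqrt"
      }
    }
  ],
  "tool_choice": "auto"
}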

Resulting exception is mangled & useless; final stanza:

  File "/Users/uche/.local/venv/temp/lib/python3.11/site-packages/transformers/tokenization_utils_base.py", line 1852, in raise_exception
    raise TemplateError(message)
jinja2.exceptions.TemplateError: System role not supported

Looks like llm-structured-output interpolates a system prompt, and Gemma refuses this. Indeed, from $HOME/.local/share/models/mlx/Gemma-2-9B-It-SPPO-Iter3-8bit/tokenizer_config.json:

"chat_template": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",

Luckily this looks like one we can patch locally in server.py, rather than needing it upstream in llm-structured-output.
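
The workaround amounts to something along these lines (a sketch only; merge_system_prompt is a hypothetical name for illustration, not necessarily the exact code that landed in server.py): fold a leading system message into the first user turn before applying the chat template.

# Sketch of the local workaround described above; function name is hypothetical.
def merge_system_prompt(messages):
    '''Fold a leading system message into the first user message, for chat
    templates (such as Gemma's) that raise on the system role.'''
    if not messages or messages[0].get('role') != 'system':
        return messages
    system, rest = messages[0], list(messages[1:])
    if rest and rest[0].get('role') == 'user':
        # Prepend the system text to the first user turn
        rest[0] = {'role': 'user',
                   'content': system['content'] + '\n\n' + rest[0]['content']}
    else:
        # No user turn to merge into; re-label the system turn as user
        rest.insert(0, {'role': 'user', 'content': system['content']})
    return rest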

uogbuji commented 3 months ago

OK, I got it past the system role error, and now I'm into something deeper.

  File "/Users/uche/.local/venv/temp/lib/python3.11/site-packages/mlx_lm/models/gemma2.py", line 159, in __call__
    h = layer(h, mask, c)
        ^^^^^^^^^^^^^^^^^
  File "/Users/uche/.local/venv/temp/lib/python3.11/site-packages/mlx_lm/models/gemma2.py", line 122, in __call__
    r = self.self_attn(self.input_layernorm(x), mask, cache)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/uche/.local/venv/temp/lib/python3.11/site-packages/mlx_lm/models/gemma2.py", line 80, in __call__
    output = mx.fast.scaled_dot_product_attention(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Shapes (1,8,2,32,36) and (32,32) cannot be broadcast.

This is with the latest relevant upstream releases:

pip show mlx mlx_lm 
Name: mlx
Version: 0.15.2
Summary: A framework for machine learning on Apple silicon.
Home-page: https://github.com/ml-explore/mlx
Author: MLX Contributors
Author-email: mlx@group.apple.com
License: 
Location: /Users/uche/.local/venv/temp/lib/python3.11/site-packages
Requires: 
Required-by: mlx-lm, Toolio
---
Name: mlx-lm
Version: 0.15.0
Summary: LLMs on Apple silicon with MLX and the Hugging Face Hub
Home-page: https://github.com/ml-explore/mlx-examples
Author: MLX Contributors
Author-email: mlx@group.apple.com
License: MIT
Location: /Users/uche/.local/venv/temp/lib/python3.11/site-packages
Requires: jinja2, mlx, numpy, protobuf, pyyaml, transformers
Required-by: Toolio
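
As a sanity check, the failure can be probed in plain mlx_lm, outside Toolio; a minimal sketch, assuming the model path from the repro steps above:

import os
from mlx_lm import load, generate

# Quantized model produced by the mlx_lm.convert step in the repro
model_path = os.path.expanduser(
    '~/.local/share/models/mlx/Gemma-2-9B-It-SPPO-Iter3-8bit')
model, tokenizer = load(model_path)
# The shape error surfaces in attention, so a short generation
# should be enough to reproduce it
print(generate(model, tokenizer, prompt='What is the square root of 256?',
               max_tokens=64))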

uogbuji commented 3 months ago

Upstream issue: https://github.com/ml-explore/mlx-examples/issues/868

Pushed the system role fix, but I'll keep this open while the upstream fix or workaround is in progress.

uogbuji commented 2 months ago

Forgot to tag that last commit: 078d815

uogbuji commented 2 months ago

Looks like it wasn't an upstream bug after all, but rather a mix-up on my end. I think Gemma should be good to go.