ml-explore / mlx-examples

Examples in the MLX framework
MIT License
5.5k stars, 791 forks

Tweaks to run dspy-produced calls to the server, with gemma template. #810

Closed namin closed 3 weeks ago

namin commented 1 month ago

Hi,

I am not sure if that is of general interest, but I thought I'd submit it in case it's useful.

I had to tweak the server a little to get it to interact with dspy, following this comment: https://github.com/stanfordnlp/dspy/issues/385#issuecomment-1998939936.

The two tweaks are (1) to relax the validation, converting from int to float as needed, and (2) to work around chat templates that do not accept the system role (such as Google Gemma's chat template).
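To illustrate tweak (1), here is a minimal sketch of lenient numeric validation. It is not the code in this PR: `coerce_float` and its parameters are hypothetical names chosen for illustration. The idea is that a client may send `"temperature": 0` (a JSON integer) where the server expects a float, so integers are accepted and converted rather than rejected.

```python
def coerce_float(name: str, value, minimum: float = 0.0) -> float:
    """Accept ints and floats for a numeric request field; return a float.

    Hypothetical helper for illustration, not part of mlx_lm.
    """
    # bool is a subclass of int in Python, so reject it explicitly.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise ValueError(f"{name} must be a number, got {type(value).__name__}")
    value = float(value)
    if value < minimum:
        raise ValueError(f"{name} must be at least {minimum}")
    return value


# A request body as a client might send it, with integer-valued fields.
body = {"temperature": 0, "top_p": 1}
temperature = coerce_float("temperature", body.get("temperature", 1.0))
top_p = coerce_float("top_p", body.get("top_p", 1.0))
```

Rejecting `bool` is a design choice of this sketch; without it, `True` would silently become `1.0`.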

You can try the patched server with:

python -m server --model mlx-community/gemma-1.1-7b-it-4bit --port 1143

modulo patching the relative imports in server.py:

-from .tokenizer_utils import TokenizerWrapper
-from .utils import generate_step, load
+from mlx_lm.tokenizer_utils import TokenizerWrapper
+from mlx_lm.utils import generate_step, load

and then, on the dspy side:

import dspy
lm = dspy.OpenAI(model_type="chat", api_base="http://localhost:1143/v1/", api_key="not_needed", max_tokens=250)
lm("hello")

Thanks!

awni commented 1 month ago

Cool! Btw you shouldn't need to change the imports if you run it as a package like so:

mlx_lm.server --model mlx-community/gemma-1.1-7b-it-4bit --port 1143

namin commented 4 weeks ago

Done with your comments. Let me know if you prefer me to rebase. Thanks!

namin commented 4 weeks ago

I tried a different application (OpenDevin), and to get it working with this server and gemma, I had to convert from OpenAI to Gemma messages as follows:

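import json
from typing import List

def openai2gemma(messages: List[dict]) -> List[dict]:
    """Map system roles to user and merge consecutive same-role messages."""
    print("before:", json.dumps(messages, indent=2))
    t_messages = []
    for m in messages:
        # Copy each message so the caller's list is not mutated.
        m = {**m, "role": m["role"].replace("system", "user")}
        if not t_messages or t_messages[-1]["role"] != m["role"]:
            t_messages.append(m)
        else:
            # Append to the previous message instead of overwriting it.
            t_messages[-1]["content"] += "\n" + m["content"]
    print("after:", json.dumps(t_messages, indent=2))
    return t_messages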

In short, replace system roles with the user role, and merge consecutive messages that share a role so that roles strictly alternate.

Transforming from OpenAI messages to conform to a specific chat template is probably out of scope of this PR...

namin commented 4 weeks ago

After further experimentation... the openai2gemma function seems useful for other models too, like llama3. In OpenDevin, mlx-community/Meta-Llama-3-70B-Instruct-4bit runs without errors even without the conversion function, but it produces questionable results.

namin commented 3 weeks ago

@awni a much better solution to deal with chat templates occurred to me. Just provide a --chat-template flag that is passed to the tokenizer. Whoever provides the template is then responsible for making it accept OpenAI-style chat messages. For gemma, this template seems to work:

--chat-template "{{ bos_token }}{% set extra_system = '' %}{% for message in messages %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{% if role == 'system' %}{% set extra_system = extra_system + message['content'] %}{% else %}{% if role == 'user' and extra_system %}{% set message_system = 'System: ' + extra_system %}{% else %}{% set message_system = '' %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message_system + message['content'] | trim + '<end_of_turn>\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"

It basically moves the convoluted conversion logic into the jinja2 template.
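A minimal sketch of how such a flag could be wired up, using only argparse. This is an assumption about the shape of the change, not the server's actual code: `build_tokenizer_config` is a hypothetical helper, and the idea that the template reaches the tokenizer through a `chat_template` entry in a config dict is likewise assumed.

```python
import argparse


def build_tokenizer_config(argv):
    """Parse the command line and return (model, tokenizer_config).

    Hypothetical helper: the template is only forwarded when the flag is
    given, so the tokenizer's built-in template remains the default.
    """
    parser = argparse.ArgumentParser(description="Sketch of a --chat-template flag")
    parser.add_argument("--model", required=True)
    parser.add_argument(
        "--chat-template",
        default="",
        help="Jinja2 chat template that overrides the tokenizer's own",
    )
    args = parser.parse_args(argv)

    tokenizer_config = {}
    if args.chat_template:
        tokenizer_config["chat_template"] = args.chat_template
    return args.model, tokenizer_config


model, tokenizer_config = build_tokenizer_config(
    ["--model", "mlx-community/gemma-1.1-7b-it-4bit", "--chat-template", "{{ bos_token }}"]
)
_, default_config = build_tokenizer_config(["--model", "mlx-community/gemma-1.1-7b-it-4bit"])
```

Leaving the config empty when the flag is absent keeps existing behavior unchanged for users who never pass a template.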

What do you think?

Thanks!

awni commented 3 weeks ago

I think that's great! Though I'm curious, where did you get the chat template for Gemma from? Is there a standard process for coming up with those?

namin commented 3 weeks ago

Cool :)

I started with the original template, found like this:

% python
>>> from mlx_lm import load
>>> 
>>> model, tokenizer = load("mlx-community/gemma-1.1-7b-it-4bit")
>>> tokenizer.chat_template
"{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"

and made it lenient instead of throwing exceptions when a system role is found or when roles do not alternate.

So in particular, as in the Python workaround I posted here earlier, I fold the system messages into the next user message, prepending System: followed by all the system contents accumulated in the meantime. That seems to work for most use cases. I don't bother checking alternation of roles at the moment.

namin commented 3 weeks ago

I did some debugging on my gemma template, and I ran into a jinja2 scoping issue (or misconception on my part): a variable defined outside a loop cannot be updated with set from inside the loop, because the assignment does not persist between iterations! I used a namespace as a workaround, as suggested here: https://stackoverflow.com/questions/46939756/setting-variable-in-jinja-for-loop-doesnt-persist-between-iterations.
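The scoping behavior can be reproduced in isolation with a small sketch, assuming the jinja2 package (2.10 or later, which introduced `namespace`) is installed. The template strings and variable names here are illustrative only and are unrelated to the gemma template.

```python
from jinja2 import Environment

env = Environment()

# A plain `set` inside a for loop only binds the name in the loop's own
# scope, so the outer `acc` is unchanged once the loop ends.
plain = env.from_string(
    "{% set acc = '' %}"
    "{% for x in items %}{% set acc = acc + x %}{% endfor %}"
    "[{{ acc }}]"
)

# Attributes of a namespace object are updated in place, so the change
# made inside the loop is still visible afterwards.
with_ns = env.from_string(
    "{% set ns = namespace(acc='') %}"
    "{% for x in items %}{% set ns.acc = ns.acc + x %}{% endfor %}"
    "[{{ ns.acc }}]"
)

plain_out = plain.render(items=["a", "b"])
ns_out = with_ns.render(items=["a", "b"])
print(plain_out)  # []
print(ns_out)     # [ab]
```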

So this is my updated chat-template parameter:

"{{ bos_token }}{% set ns = namespace(extra_system='') %}{% for message in messages %}{% set role = message['role'] %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% endif %}{% if (role == 'system') %}{% set ns.extra_system = ns.extra_system + message['content'] %}{% else %}{% set message_system = '' %}{% if (role == 'user') %}{% if (ns.extra_system == '') %}{% else %}{% set message_system = 'System: ' + ns.extra_system + '\\n' %}{% set ns.extra_system = '' %}{% endif %}{% endif %}{{ '<start_of_turn>' + role + '\\n' + message_system + message['content'] | trim + '<end_of_turn>\\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\\n'}}{% endif %}"

For readability:

{{ bos_token }}
{% set ns = namespace(extra_system='') %}
{% for message in messages %}
  {% set role = message['role'] %}
  {% if (message['role'] == 'assistant') %}
    {% set role = 'model' %}
  {% endif %}
  {% if (role == 'system') %}
    {% set ns.extra_system = ns.extra_system + message['content'] %}
  {% else %}
    {% set message_system = '' %}
    {% if (role == 'user') %}
      {% if (ns.extra_system == '') %}
      {% else %}
        {% set message_system = 'System: ' + ns.extra_system + '\n' %}
        {% set ns.extra_system = '' %}
      {% endif %}
    {% endif %}
    {{ '<start_of_turn>' + role + '\n' + message_system + message['content'] | trim + '<end_of_turn>\n' }}
  {% endif %}
{% endfor %}
{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}
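For checking the template's behavior on sample conversations, the same logic can be mirrored in plain Python. This is a sketch, not a replacement for the template: `render_gemma_prompt` is a hypothetical name, and the `"<bos>"` default for the BOS token is an assumption (the real value comes from the tokenizer).

```python
from typing import Dict, List


def render_gemma_prompt(
    messages: List[Dict[str, str]],
    bos_token: str = "<bos>",  # assumed default; the tokenizer supplies the real one
    add_generation_prompt: bool = True,
) -> str:
    """Mirror the lenient template: fold system messages into the next user turn."""
    parts = [bos_token]
    pending_system = ""
    for message in messages:
        role = "model" if message["role"] == "assistant" else message["role"]
        if role == "system":
            # Accumulate system content until the next user message.
            pending_system += message["content"]
            continue
        prefix = ""
        if role == "user" and pending_system:
            prefix = "System: " + pending_system + "\n"
            pending_system = ""
        # As in the template, only the message content is trimmed.
        parts.append(
            "<start_of_turn>" + role + "\n" + prefix
            + message["content"].strip() + "<end_of_turn>\n"
        )
    if add_generation_prompt:
        parts.append("<start_of_turn>model\n")
    return "".join(parts)


prompt = render_gemma_prompt(
    [
        {"role": "system", "content": "Be brief."},
        {"role": "user", "content": " Hi "},
        {"role": "assistant", "content": "Hello"},
        {"role": "user", "content": "Bye"},
    ]
)
```

Like the template, this sketch does not check that roles alternate, and system content that is never followed by a user message is dropped.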