simonw / llm-mlc

LLM plugin for running models using MLC
Apache License 2.0
174 stars 8 forks source link

Get token streaming working #2

Closed simonw closed 11 months ago

simonw commented 11 months ago

This proved a bit tricky, because the MLC library works based on a callback mechanism:

from mlc_chat import ChatModule
from mlc_chat.callback import StreamToStdout

# Build a ChatModule for the named model, then generate a reply to the prompt.
# Output is delivered through the object passed as progress_callback rather
# than being handed back piece by piece to the caller.
cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1")
cm.generate(
   prompt="A poem about a bunny eating lunch",
   progress_callback=StreamToStdout(callback_interval=1),
)

But... LLM expects to be able to do something like this:

# Pseudo-code: the iterator shape LLM wants to consume (this would live inside
# a generator function; "..." stands for the real arguments).
for chunk in cm.generate(...):
    yield chunk
simonw commented 11 months ago

I tried some fancy code to run it in a thread and turn that into an iterator via a queue... but it crashed with some kind of C-level error!

Here's as far as I got with that:

import click
import contextlib
import httpx
import io
import json
import llm
import os
import pathlib
import sys
import subprocess
import textwrap
import threading
import queue

# Short model names accepted by the download_model command, mapped to the
# repository URL that is git-cloned for each of them.
MODEL_URLS = {
    "Llama-2-7b-chat": "https://huggingface.co/mlc-ai/mlc-chat-Llama-2-7b-chat-hf-q4f16_1",
    "Llama-2-13b-chat": "https://huggingface.co/mlc-ai/mlc-chat-Llama-2-13b-chat-hf-q4f16_1",
    "Llama-2-70b-chat": "https://huggingface.co/mlc-ai/mlc-chat-Llama-2-70b-chat-hf-q4f16_1",
}

def is_git_lfs_command_available():
    """Return True if running ``git lfs`` succeeds, False otherwise.

    The command's output is discarded; only its exit status matters.

    Returns:
        bool: True when ``git lfs`` exits with status 0. False when it exits
        non-zero, or when the ``git`` executable itself cannot be launched
        (for example because it is not on PATH).
    """
    try:
        subprocess.run(
            ["git", "lfs"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=True,
        )
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: git ran but "lfs" failed.
        # OSError (e.g. FileNotFoundError): git could not be started at all.
        return False
    return True

def is_git_lfs_installed():
    """Return True if git's ``filter.lfs.clean`` setting mentions ``git-lfs clean``.

    Reads the value with ``git config --get filter.lfs.clean``.

    Returns:
        bool: True when the configured value contains "git-lfs clean". False
        when the value is different, when the command fails (the option is
        likely not set), or when ``git`` itself cannot be launched.
    """
    try:
        # Run the git config command to get the filter value
        result = subprocess.check_output(
            ["git", "config", "--get", "filter.lfs.clean"], encoding="utf-8"
        ).strip()
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: the configuration option is probably not set.
        # OSError (e.g. FileNotFoundError): git could not be started at all.
        return False
    # If the result contains "git-lfs clean", it's likely that Git LFS is installed
    return "git-lfs clean" in result

def _ensure_models_dir():
    """Create the ``llama-cpp/models`` directory under the LLM user dir and return it.

    Missing parent directories are created; an existing directory is left as is.
    """
    # NOTE(review): this uses "llama-cpp" while the rest of this module works
    # under "mlc" - possibly left over from another plugin; confirm.
    models_path = llm.user_dir() / "llama-cpp" / "models"
    models_path.mkdir(exist_ok=True, parents=True)
    return models_path

def _ensure_models_file():
    """Return the path to ``llama-cpp/models.json``, creating it if it is missing.

    The containing directory is created when needed, and a new file is
    initialised with an empty JSON object (``{}``).
    """
    # NOTE(review): this uses "llama-cpp" while the rest of this module works
    # under "mlc" - possibly left over from another plugin; confirm.
    base_dir = llm.user_dir() / "llama-cpp"
    base_dir.mkdir(exist_ok=True, parents=True)
    models_json = base_dir / "models.json"
    if models_json.exists():
        return models_json
    models_json.write_text("{}")
    return models_json

@llm.hookimpl
def register_models(register):
    """Register an MlcModel for every model directory under ``mlc/dist/prebuilt``.

    Each sub-directory of the prebuilt directory, except the one named "lib",
    is registered using the directory name as the model ID.

    Args:
        register: callable that receives each MlcModel instance.
    """
    models_dir = llm.user_dir() / "mlc" / "dist" / "prebuilt"
    if not models_dir.is_dir():
        # The directory is only created by the "setup" command; before that
        # there is nothing to register, and iterdir() would raise
        # FileNotFoundError from inside the plugin hook.
        return
    for child in models_dir.iterdir():
        if child.is_dir() and child.name != "lib":
            # It's a model! Register it
            register(
                MlcModel(
                    model_id=child.name,
                    model_path=str(child.absolute()),
                )
            )

@llm.hookimpl
def register_commands(cli):
    """Add the ``mlc`` command group, with its ``setup`` and ``download_model`` commands.

    Args:
        cli: the Click group that the ``mlc`` group is attached to.
    """

    @cli.group()
    def mlc():
        "Commands for managing MLC models"

    @mlc.command()
    def setup():
        "Finish setting up MLC, step by step"
        directory = llm.user_dir() / "mlc"
        directory.mkdir(parents=True, exist_ok=True)
        if not is_git_lfs_command_available():
            raise click.ClickException(
                "Git LFS is not installed. You should run 'brew install git-lfs' or similar."
            )
        if not is_git_lfs_installed():
            click.echo(
                "Git LFS is not installed. Should I run 'git lfs install' for you?"
            )
            if click.confirm("Install Git LFS?"):
                subprocess.run(["git", "lfs", "install"], check=True)
            else:
                raise click.ClickException(
                    "Git LFS is not installed. You should run 'git lfs install'."
                )
        # Now we have git-lfs installed, ensure we have cloned the repo.
        # Test for the clone target itself rather than for dist/: if an
        # earlier clone failed after dist/prebuilt was created, checking
        # dist/ would make every later run skip the download.
        prebuilt_dir = directory / "dist" / "prebuilt"
        lib_dir = prebuilt_dir / "lib"
        if not lib_dir.exists():
            click.echo("Downloading prebuilt binaries...")
            # mkdir -p dist/prebuilt
            prebuilt_dir.mkdir(parents=True, exist_ok=True)
            # git clone
            git_clone_command = [
                "git",
                "clone",
                "https://github.com/mlc-ai/binary-mlc-llm-libs.git",
                str(lib_dir.absolute()),
            ]
            subprocess.run(git_clone_command, check=True)
        click.echo("Ready to install models in {}".format(directory))

    @mlc.command(
        help=textwrap.dedent(
            """
        Download and register a model from a URL

        Try one of these names:

        \b
        {}
        """
        ).format("\n".join("- {}".format(key) for key in MODEL_URLS.keys()))
    )
    @click.argument("name_or_url")
    def download_model(name_or_url):
        # Accept either a known short name from MODEL_URLS or a full URL.
        url = MODEL_URLS.get(name_or_url) or name_or_url
        if not url.startswith("https://"):
            raise click.BadParameter("Invalid model name or URL")
        directory = llm.user_dir() / "mlc"
        prebuilt_dir = directory / "dist" / "prebuilt"
        if not prebuilt_dir.exists():
            raise click.ClickException("You must run 'llm mlc setup' first")
        # Run git clone URL dist/prebuilt
        # Strip trailing slashes first: otherwise the last path segment is
        # empty and the clone target would be the prebuilt directory itself.
        last_bit = url.rstrip("/").split("/")[-1]
        git_clone_command = [
            "git",
            "clone",
            url,
            str((prebuilt_dir / last_bit).absolute()),
        ]
        subprocess.run(git_clone_command, check=True)

class MlcModel(llm.Model):
    """LLM model that generates text with an MLC ``ChatModule``.

    The ChatModule is created on the first call to ``execute`` and then
    reused for later calls.
    """

    # class Options(llm.Options):
    #     verbose: bool = False

    def __init__(self, model_id, model_path, lib_path=None):
        """Store the model settings; the ChatModule is not loaded yet.

        Args:
            model_id: identifier this model is registered under.
            model_path: value passed to ChatModule as ``model``.
            lib_path: optional value passed to ChatModule as ``lib_path``.
        """
        self.model_id = model_id
        self.model_path = model_path
        self.lib_path = lib_path
        self.chat_mod = None  # Lazy loading

    def execute(self, prompt, stream, response, conversation):
        """Run the prompt on a worker thread and yield output as it arrives.

        ``ChatModule.generate`` reports progress through a callback, so it is
        run in a thread whose callback pushes each delta onto a queue; this
        generator drains the queue.

        Args:
            prompt: object whose ``.prompt`` attribute is the prompt text.
            stream: not used by this implementation.
            response: not used by this implementation.
            conversation: must be falsy; conversations are not supported.

        Yields:
            Each delta message passed to the progress callback.

        Raises:
            click.ClickException: if ``conversation`` is truthy.
            Exception: any exception raised by ``generate`` in the worker
                thread is re-raised here once the thread has finished.
        """
        import mlc_chat

        from mlc_chat.callback import DeltaCallback

        class QueueCallback(DeltaCallback):
            """Progress callback that forwards every delta onto a queue."""

            def __init__(self, q, done, callback_interval=1):
                # Let the base class set up whatever state it needs.
                super().__init__()
                self.queue = q
                self.done = done
                # generate() reads progress_callback.callback_interval, so
                # the attribute has to exist on the callback object.
                self.callback_interval = callback_interval

            def delta_callback(self, delta_message):
                self.queue.put(delta_message)

            def stopped_callback(self):
                self.queue.put(self.done)  # Indicate end of items

        if conversation:
            raise click.ClickException("Conversation mode is not supported yet")
        if self.chat_mod is None:
            with temp_chdir(llm.user_dir() / "mlc"):
                self.chat_mod = mlc_chat.ChatModule(
                    model=self.model_path, lib_path=self.lib_path
                )

        q = queue.Queue()
        # Private end-of-stream marker, so no queued message can be mistaken
        # for the end of the stream.
        done = object()
        failures = []

        def run_prompt(chat_mod, prompt, q):
            try:
                chat_mod.generate(
                    prompt=prompt.prompt,
                    progress_callback=QueueCallback(q, done),
                )
            except Exception as ex:
                # Keep the error so the consuming side can re-raise it.
                failures.append(ex)
            finally:
                # Always unblock the consumer, even if generate() failed
                # before stopped_callback ran. A second marker after a
                # normal stop is harmless: the consumer exits on the first.
                q.put(done)

        t = threading.Thread(target=run_prompt, args=(self.chat_mod, prompt, q))
        t.start()
        while True:
            item = q.get()
            if item is done:  # sentinel to break the loop
                break
            yield item
        t.join()
        if failures:
            raise failures[0]

@contextlib.contextmanager
def temp_chdir(path):
    """Context manager that switches the working directory to ``path``.

    The directory that was current on entry is restored on exit, whether
    the ``with`` body finishes normally or raises.
    """
    with contextlib.ExitStack() as cleanup:
        previous_dir = os.getcwd()
        os.chdir(path)
        # Registered only after the chdir succeeded, so a failed chdir
        # leaves nothing to undo.
        cleanup.callback(os.chdir, previous_dir)
        yield
simonw commented 11 months ago

I think I can avoid the callback thing entirely by looking at the implementation of that .generate() method:

https://github.com/mlc-ai/mlc-llm/blob/94e0109d78517a9df4ae7354f1da1ac4190a5c1d/python/mlc_chat/chat_module.py#L616-L668

Here's the relevant subset of code extracted from ChatModule.generate()

    def generate(self, prompt: str, progress_callback=None) -> str:
        """Decode until stopped, reporting progress through ``progress_callback``.

        This is an abridged excerpt. The callback is called with the latest
        message every ``callback_interval`` decode steps and on the final
        step, then once more with ``stopped=True``.

        NOTE(review): ``progress_callback`` defaults to None but is used
        unconditionally below, so this excerpt needs a callback object that
        has a ``callback_interval`` attribute.

        Returns:
            The last message fetched with ``_get_message()`` ("" if the loop
            never fetched one).
        """
        self._prefill(prompt)
        # apply callback with a rate of callback_interval
        i, new_msg = 0, ""
        while not self._stopped():
            self._decode()
            # Report on every callback_interval-th step, and always on the
            # step after which generation has stopped.
            if i % progress_callback.callback_interval == 0 or self._stopped():
                new_msg = self._get_message()
                progress_callback(new_msg)
            i += 1
        # Signal the end of generation to the callback.
        progress_callback(stopped=True)
        return new_msg

So I could subclass ChatModule and add my own generate_iter() method that looks similar to that.

simonw commented 11 months ago

Note that new_msg appears to be over-written each time with a longer message.

That's what the StreamToStdout class does - it subclasses DeltaCallback in https://github.com/mlc-ai/mlc-llm/blob/94e0109d78517a9df4ae7354f1da1ac4190a5c1d/python/mlc_chat/callback.py which uses a special method to figure out what has changed:

https://github.com/mlc-ai/mlc-llm/blob/94e0109d78517a9df4ae7354f1da1ac4190a5c1d/python/mlc_chat/base.py#L28-L45

def get_delta_message(curr_message: str, new_message: str) -> str:
    """Call the TVM global function ``mlc.get_delta_message`` and return its result.

    Presumably the result is the text by which ``new_message`` differs from
    ``curr_message`` - confirm against the MLC implementation.
    """
    return tvm.get_global_func("mlc.get_delta_message")(curr_message, new_message)
simonw commented 11 months ago

Got it to work!

I used this devious trick to disable all of the print() statements in that one module:

https://github.com/simonw/llm-mlc/blob/df522d602de524a81c2d6ba1ef35a89f6da41f8b/llm_mlc.py#L203-L207

I still had to use the SuppressOutput trick too, to disable one message that was being output by some C code somewhere.

I restored the SuppressOutput version that does NOT interfere with the default Python sys.stdout so that the streaming output would show up correctly.

simonw commented 11 months ago

owl