I tried some fancy code to run it in a thread and turn that into an iterator via a queue... but it crashed with some kind of C-level error!
Here's as far as I got with that:
import click
import contextlib
import httpx
import io
import json
import llm
import os
import pathlib
import sys
import subprocess
import textwrap
import threading
import queue

MODEL_URLS = {
    "Llama-2-7b-chat": "https://huggingface.co/mlc-ai/mlc-chat-Llama-2-7b-chat-hf-q4f16_1",
    "Llama-2-13b-chat": "https://huggingface.co/mlc-ai/mlc-chat-Llama-2-13b-chat-hf-q4f16_1",
    "Llama-2-70b-chat": "https://huggingface.co/mlc-ai/mlc-chat-Llama-2-70b-chat-hf-q4f16_1",
}


def is_git_lfs_command_available():
    try:
        subprocess.run(
            ["git", "lfs"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=True,
        )
        return True
    except subprocess.CalledProcessError:
        return False


def is_git_lfs_installed():
    try:
        # Run the git config command to get the filter value
        result = subprocess.check_output(
            ["git", "config", "--get", "filter.lfs.clean"], encoding="utf-8"
        ).strip()
        # If the result contains "git-lfs clean", it's likely that Git LFS is installed
        if "git-lfs clean" in result:
            return True
        else:
            return False
    except subprocess.CalledProcessError:
        # If the command fails, it's likely that the configuration option isn't set
        return False


def _ensure_models_dir():
    directory = llm.user_dir() / "llama-cpp" / "models"
    directory.mkdir(parents=True, exist_ok=True)
    return directory


def _ensure_models_file():
    directory = llm.user_dir() / "llama-cpp"
    directory.mkdir(parents=True, exist_ok=True)
    filepath = directory / "models.json"
    if not filepath.exists():
        filepath.write_text("{}")
    return filepath


@llm.hookimpl
def register_models(register):
    directory = llm.user_dir() / "mlc"
    models_dir = directory / "dist" / "prebuilt"
    for child in models_dir.iterdir():
        if child.is_dir() and child.name != "lib":
            # It's a model! Register it
            register(
                MlcModel(
                    model_id=child.name,
                    model_path=str(child.absolute()),
                )
            )


@llm.hookimpl
def register_commands(cli):
    @cli.group()
    def mlc():
        "Commands for managing MLC models"

    @mlc.command()
    def setup():
        "Finish setting up MLC, step by step"
        directory = llm.user_dir() / "mlc"
        directory.mkdir(parents=True, exist_ok=True)
        if not is_git_lfs_command_available():
            raise click.ClickException(
                "Git LFS is not installed. You should run 'brew install git-lfs' or similar."
            )
        if not is_git_lfs_installed():
            click.echo(
                "Git LFS is not installed. Should I run 'git lfs install' for you?"
            )
            if click.confirm("Install Git LFS?"):
                subprocess.run(["git", "lfs", "install"], check=True)
            else:
                raise click.ClickException(
                    "Git LFS is not installed. You should run 'git lfs install'."
                )
        # Now we have git-lfs installed, ensure we have cloned the repo
        dist_dir = directory / "dist"
        if not dist_dir.exists():
            click.echo("Downloading prebuilt binaries...")
            # mkdir -p dist/prebuilt
            (dist_dir / "prebuilt").mkdir(parents=True, exist_ok=True)
            # git clone
            git_clone_command = [
                "git",
                "clone",
                "https://github.com/mlc-ai/binary-mlc-llm-libs.git",
                str((dist_dir / "prebuilt" / "lib").absolute()),
            ]
            subprocess.run(git_clone_command, check=True)
        click.echo("Ready to install models in {}".format(directory))

    @mlc.command(
        help=textwrap.dedent(
            """
            Download and register a model from a URL
            Try one of these names:
            \b
            {}
            """
        ).format("\n".join("- {}".format(key) for key in MODEL_URLS.keys()))
    )
    @click.argument("name_or_url")
    def download_model(name_or_url):
        url = MODEL_URLS.get(name_or_url) or name_or_url
        if not url.startswith("https://"):
            raise click.BadParameter("Invalid model name or URL")
        directory = llm.user_dir() / "mlc"
        prebuilt_dir = directory / "dist" / "prebuilt"
        if not prebuilt_dir.exists():
            raise click.ClickException("You must run 'llm mlc setup' first")
        # Run git clone URL dist/prebuilt
        last_bit = url.split("/")[-1]
        git_clone_command = [
            "git",
            "clone",
            url,
            str((prebuilt_dir / last_bit).absolute()),
        ]
        subprocess.run(git_clone_command, check=True)


class MlcModel(llm.Model):
    # class Options(llm.Options):
    #     verbose: bool = False

    def __init__(self, model_id, model_path, lib_path=None):
        self.model_id = model_id
        self.model_path = model_path
        self.lib_path = lib_path
        self.chat_mod = None  # Lazy loading

    def execute(self, prompt, stream, response, conversation):
        import mlc_chat
        from mlc_chat.callback import DeltaCallback

        class QueueCallback(DeltaCallback):
            def __init__(self, q):
                self.queue = q

            def delta_callback(self, delta_message):
                self.queue.put(delta_message)

            def stopped_callback(self):
                self.queue.put(None)  # Indicate end of items

        if conversation:
            raise click.ClickException("Conversation mode is not supported yet")
        if self.chat_mod is None:
            with temp_chdir(llm.user_dir() / "mlc"):
                self.chat_mod = mlc_chat.ChatModule(
                    model=self.model_path, lib_path=self.lib_path
                )
        q = queue.Queue()

        def run_prompt(chat_mod, prompt, q):
            chat_mod.generate(
                prompt=prompt.prompt,
                progress_callback=QueueCallback(q),
            )

        t = threading.Thread(target=run_prompt, args=(self.chat_mod, prompt, q))
        t.start()
        while True:
            item = q.get()
            if item is None:  # sentinel to break the loop
                break
            yield item
            q.task_done()
        t.join()


@contextlib.contextmanager
def temp_chdir(path):
    old_dir = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(old_dir)
I think I can avoid the callback thing entirely by looking at the implementation of that .generate() method. Here's the relevant subset of code extracted from ChatModule.generate():
def generate(self, prompt: str, progress_callback=None) -> str:
    self._prefill(prompt)
    # apply callback with a rate of callback_interval
    i, new_msg = 0, ""
    while not self._stopped():
        self._decode()
        if i % progress_callback.callback_interval == 0 or self._stopped():
            new_msg = self._get_message()
            progress_callback(new_msg)
        i += 1
    progress_callback(stopped=True)
    return new_msg
So I could subclass ChatModule and add my own generate_iter() method that looks similar to that. Note that new_msg appears to be over-written each time with a longer message. Handling that is what the StreamToStdout class does - it subclasses DeltaCallback in https://github.com/mlc-ai/mlc-llm/blob/94e0109d78517a9df4ae7354f1da1ac4190a5c1d/python/mlc_chat/callback.py, which uses a special helper to figure out what has changed:
def get_delta_message(curr_message: str, new_message: str) -> str:
    f_get_delta_message = tvm.get_global_func("mlc.get_delta_message")
    return f_get_delta_message(curr_message, new_message)
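Putting those pieces together, a generate_iter() could look roughly like this. This is only a sketch: it assumes ChatModule's private _prefill()/_decode()/_stopped()/_get_message() methods behave as in the excerpt above, and it inlines the "mlc.get_delta_message" lookup via tvm rather than importing a helper.

import tvm
from mlc_chat import ChatModule


class IterableChatModule(ChatModule):
    """Hypothetical subclass that yields deltas instead of invoking a callback."""

    def generate_iter(self, prompt: str):
        # Mirrors the structure of ChatModule.generate() above, but yields
        # just the newly-added text after each decode step.
        f_get_delta_message = tvm.get_global_func("mlc.get_delta_message")
        self._prefill(prompt)
        curr_message = ""
        while not self._stopped():
            self._decode()
            new_message = self._get_message()
            delta = f_get_delta_message(curr_message, new_message)
            curr_message = new_message
            if delta:
                yield delta

With something like that, execute() could just do for chunk in self.chat_mod.generate_iter(prompt.prompt): yield chunk - no thread or queue required.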
Got it to work! I used this devious trick to disable all of the print() statements in that one module: https://github.com/simonw/llm-mlc/blob/df522d602de524a81c2d6ba1ef35a89f6da41f8b/llm_mlc.py#L203-L207
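The general idea (a sketch of the technique, not necessarily the exact code at that link) is to shadow the built-in print inside just that one module with a no-op:

import mlc_chat.chat_module

# Assumption: the noisy print() calls live in mlc_chat.chat_module.
# Assigning a module-level `print` shadows the builtin for that module only;
# everything else in the process keeps printing normally.
mlc_chat.chat_module.print = lambda *args, **kwargs: None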
I still had to use the SuppressOutput trick too, to disable one message that was being output by some C code somewhere. I restored the SuppressOutput version that does NOT interfere with the default Python sys.stdout, so that the streaming output would show up correctly.
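For reference, here is a minimal sketch of that kind of fd-level suppression (an illustration of the approach, not the exact SuppressOutput implementation): it temporarily points file descriptor 1 at /dev/null, which silences output written directly from C code during that window, and restores it afterwards without ever replacing the sys.stdout object, so streaming output after the block still works.

import contextlib
import os
import sys


@contextlib.contextmanager
def suppress_fd_output():
    # Anything written to file descriptor 1 (including writes from C code)
    # goes to /dev/null while the block runs; fd 1 is restored on exit.
    sys.stdout.flush()
    saved_fd = os.dup(1)
    devnull_fd = os.open(os.devnull, os.O_WRONLY)
    try:
        os.dup2(devnull_fd, 1)
        yield
    finally:
        os.dup2(saved_fd, 1)
        os.close(devnull_fd)
        os.close(saved_fd)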
This proved a bit tricky, because the MLC library works based on a callback mechanism:
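Roughly like this (a sketch based on the mlc_chat examples from around that time; the model name is a placeholder and the exact argument names may differ slightly):

from mlc_chat import ChatModule
from mlc_chat.callback import StreamToStdout

cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1")
# The library pushes partial output to a callback as it generates,
# rather than handing back an iterator.
cm.generate(
    prompt="Write a haiku about llamas",
    progress_callback=StreamToStdout(callback_interval=2),
)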
But... LLM expects to be able to do something like this:
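Something along these lines (a minimal sketch of the shape LLM's plugin API expects; stream_tokens_from_model is a placeholder for whatever actually produces the chunks):

import llm


class MyModel(llm.Model):
    model_id = "my-model"

    def execute(self, prompt, stream, response, conversation):
        # LLM treats execute() as a generator: it iterates over chunks of
        # text as they are produced and streams them to the user.
        for chunk in stream_tokens_from_model(prompt.prompt):
            yield chunk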