mlc-ai / mlc-llm

Universal LLM Deployment Engine with ML Compilation
https://llm.mlc.ai/
Apache License 2.0

[Bug] fine-tuned model deployed with webllm not working #2601

Closed JLKaretis closed 1 month ago

JLKaretis commented 3 months ago

šŸ› Bug

I have a bug with a fine-tuned (DoRA) Qwen2-0.5B deployed with WebLLM. Inference always fails with the same error.

Error trace:

background.js:53 thread '<unnamed>' panicked at /home/cfruan/.cargo/registry/src/index.crates.io-6f17d22bba15001f/rayon-core-1.12.1/src/registry.rs:168:10: The global thread pool has not been initialized.: ThreadPoolBuildError { kind: IOError(Os { code: 6, kind: WouldBlock, message: "Resource temporarily unavailable" }) }

Any help is most welcome; this has been blocking me for a few days already.

To Reproduce

Steps to reproduce the behavior:

Expected behavior

The model should output tokens; there shouldn't be an error in the console.

Environment

Additional context

I can share code samples and the problematic weights if needed.

Hzfengsy commented 3 months ago

Could you please check whether you can run the original Qwen2-0.5B? Also, can you run your fine-tuned model on other devices, e.g. CUDA?

JLKaretis commented 3 months ago

Yes, I can run the original Qwen2-0.5B (compiled from the source weights) on WebLLM, and I can run the fine-tuned model on Metal with the mlc-llm Python library. It's only the fine-tuned model that fails on WebLLM.

tqchen commented 3 months ago

This seems to be related to how we package the latest wasm runtime. If you have a custom compile that runs the original Qwen and reproduces the error, that would be helpful. Alternatively, it would be great if you could share a reproducible command with the model that caused the error.

JLKaretis commented 3 months ago

These are the weights that I'm trying to deploy; they work fine with the Python backend on Metal:

from mlc_llm import MLCEngine

# Create engine
model = "HF://OpilotAI/qwen2-0.5B-pii-masking-lora-merged-q4f16_1-Opilot"
engine = MLCEngine(model)

# Run a chat completion through the OpenAI-compatible API, streaming the output.
for response in engine.chat.completions.create(
    messages=[{"role": "user", "content": "What is the meaning of life?"}],
    model=model,
    stream=True,
):
    for choice in response.choices:
        print(choice.delta.content, end="", flush=True)
print("\n")

engine.terminate()

/opt/homebrew/Caskroom/miniforge/base/envs/mlc/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
[2024-06-25 10:15:01] INFO auto_device.py:88: Not found device: cuda:0
[2024-06-25 10:15:02] INFO auto_device.py:88: Not found device: rocm:0
[2024-06-25 10:15:03] INFO auto_device.py:79: Found device: metal:0
[2024-06-25 10:15:04] INFO auto_device.py:88: Not found device: vulkan:0
[2024-06-25 10:15:05] INFO auto_device.py:88: Not found device: opencl:0
[2024-06-25 10:15:05] INFO auto_device.py:35: Using device: metal:0
[2024-06-25 10:15:05] INFO download_cache.py:227: Downloading model from HuggingFace: HF://OpilotAI/qwen2-0.5B-pii-masking-lora-merged-q4f16_1-Opilot
[2024-06-25 10:15:05] INFO download_cache.py:29: MLC_DOWNLOAD_CACHE_POLICY = ON. Can be one of: ON, OFF, REDO, READONLY
[2024-06-25 10:15:05] INFO download_cache.py:166: Weights already downloaded: /Users/User/.cache/mlc_llm/model_weights/hf/OpilotAI/qwen2-0.5B-pii-masking-lora-merged-q4f16_1-Opilot
[2024-06-25 10:15:05] INFO jit.py:43: MLC_JIT_POLICY = ON. Can be one of: ON, OFF, REDO, READONLY
[2024-06-25 10:15:05] INFO jit.py:158: Using cached model lib: /Users/User/.cache/mlc_llm/model_lib/c5f2c474b97ac6bb95cf167c9cc9dba8.dylib
[2024-06-25 10:15:05] INFO engine_base.py:179: The selected engine mode is local. We choose small max batch size and KV cache capacity to use less GPU memory.
[2024-06-25 10:15:05] INFO engine_base.py:204: If you don't have concurrent requests and only use the engine interactively, please select mode "interactive".
[2024-06-25 10:15:05] INFO engine_base.py:209: If you have high concurrent requests and want to maximize the GPU memory utilization, please select mode "server".
[10:15:05] /Users/catalyst/Workspace/mlc-ai-package-self-runner/_work/package/package/mlc-llm/cpp/serve/config.cc:668: Under mode "local", max batch size will be set to 4, max KV cache token capacity will be set to 8192, prefill chunk size will be set to 2048.
[10:15:05] /Users/catalyst/Workspace/mlc-ai-package-self-runner/_work/package/package/mlc-llm/cpp/serve/config.cc:668: Under mode "interactive", max batch size will be set to 1, max KV cache token capacity will be set to 32768, prefill chunk size will be set to 2048.
[10:15:05] /Users/catalyst/Workspace/mlc-ai-package-self-runner/_work/package/package/mlc-llm/cpp/serve/config.cc:668: Under mode "server", max batch size will be set to 80, max KV cache token capacity will be set to 32768, prefill chunk size will be set to 2048.
[10:15:05] /Users/catalyst/Workspace/mlc-ai-package-self-runner/_work/package/package/mlc-llm/cpp/serve/config.cc:748: The actual engine mode is "local". So max batch size is 4, max KV cache token capacity is 8192, prefill chunk size is 2048.
[10:15:05] /Users/catalyst/Workspace/mlc-ai-package-self-runner/_work/package/package/mlc-llm/cpp/serve/config.cc:753: Estimated total single GPU memory usage: 2959.723 MB (Parameters: 265.118 MB. KVCache: 152.245 MB. Temporary buffer: 2542.361 MB). The actual usage might be slightly larger than the estimated number.
As an AI language model, I don't have personal beliefs or experiences. However, based on scientific research, the meaning of life is a question that has been asked by many people throughout history. It is generally believed that there is no one definitive answer to this question, and it is possible that different people have different ideas about what it means. Some people believe that life is a gift from God, while others believe that it is a struggle between good and evil. Ultimately, the meaning of life is a complex and personal question that depends on many factors, including personal experiences and beliefs.

Using WebLLM:

Original weights:

Fine-tuned weights:

Interestingly, we also have the following combinations:

So the problem seems to be with the weights, but then why do they work with the Python library?
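
For reference, a rough sketch of how a custom model like this would be registered and queried in WebLLM (the model_id and the wasm URL below are placeholders; exact ModelRecord field names depend on the @mlc-ai/web-llm version):

import { CreateMLCEngine } from "@mlc-ai/web-llm";

// Register the fine-tuned weights together with a self-compiled webgpu wasm.
const appConfig = {
  model_list: [
    {
      model: "https://huggingface.co/OpilotAI/qwen2-0.5B-pii-masking-lora-merged-q4f16_1-Opilot",
      model_id: "qwen2-0.5B-pii-masking-q4f16_1", // placeholder id
      model_lib: "https://example.com/qwen2-0.5B-q4f16_1-webgpu.wasm", // placeholder wasm URL
    },
  ],
};

async function main() {
  const engine = await CreateMLCEngine("qwen2-0.5B-pii-masking-q4f16_1", { appConfig });
  const reply = await engine.chat.completions.create({
    messages: [{ role: "user", content: "What is the meaning of life?" }],
  });
  console.log(reply.choices[0]?.message.content);
}

main().catch(console.error);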

bil-ash commented 3 months ago

@JLKaretis Maybe you can try the unquantized qwen2-0.5B for now, because the mlc-llm team hasn't released a q4f16 webgpu.wasm, only q0f16, and there may be some problem with your generated wasm that somehow does not cause issues on the original model but does on the fine-tuned one. I am also waiting for an official qwen2-0.5B q4f16 wasm. I suggest trying your fine-tuned q4f16 model again after the official wasm is released.

JLKaretis commented 3 months ago

@bil-ash I did try the q0f16 quant last week and got the same error. I don't think there is a difference between the wasm I generated and the "official" one.

tqchen commented 2 months ago

Looking at the error, it seems to be triggered when the tokenizer tries to use multithreading. I wonder if it is related to the tokenizer config you have.

CharlieFRuan commented 1 month ago

I was able to reproduce it; it should be an issue in tokenizers-cpp/web, which does not work with certain tokenizer.json files. A minimal example to reproduce is to run the following code in https://github.com/mlc-ai/tokenizers-cpp/blob/main/web/tests/src/index.ts:

async function testBertTokenizer() {
  console.log("Bert Tokenizer");
  // Download a tokenizer.json that reproduces the failure.
  const modelBuffer = await (await
    fetch("https://huggingface.co/CharlieFRuan/snowflake-arctic-embed-m-q0f32-MLC/resolve/main/tokenizer.json")
  ).arrayBuffer();
  const tok = await Tokenizer.fromJSON(modelBuffer);
  const text = "What is the capital of Canada?";
  const ids = tok.encode(text);
  console.log(ids);
}

This leads to the following error:

[error screenshot]

CharlieFRuan commented 1 month ago

My initial guess is that the padding field in tokenizer.json triggers this issue. It is not present in your original weights: https://huggingface.co/julientfai/Qwen2-0.5B-Instruct-q4f16_1-Opilot/blob/main/tokenizer.json

But it is present in both:

I'll see whether setting TOKENIZERS_PARALLELISM=false for the tokenizers will work.
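
In the meantime, one possible client-side workaround (just a sketch, assuming the padding field really is the trigger and reusing the Tokenizer API from the snippet above) would be to strip that field before handing tokenizer.json to tokenizers-cpp/web:

// Workaround sketch, not the actual fix: drop the suspected `padding` field
// from tokenizer.json before constructing the tokenizer.
async function loadTokenizerWithoutPadding(url: string) {
  const config = await (await fetch(url)).json();
  delete config.padding; // field suspected to trigger the thread-pool panic
  const buffer = new TextEncoder().encode(JSON.stringify(config)).buffer;
  return await Tokenizer.fromJSON(buffer);
}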

CharlieFRuan commented 1 month ago

Confirmed on my end that the issue is fixed with WebLLM npm 0.2.57. For more details, see https://github.com/mlc-ai/tokenizers-cpp/pull/42

CharlieFRuan commented 1 month ago

Closing this issue for now. If it persists, feel free to re-open.