mlc-ai / mlc-llm

Universal LLM Deployment Engine with ML Compilation
https://llm.mlc.ai/
Apache License 2.0

[Bug] Compiling custom fine-tuned Mistral 7B leads to TVM BlockBuilder error #1794

Closed — kkoehncke closed this issue 6 months ago

kkoehncke commented 7 months ago

🐛 Bug

I have a custom fine-tuned Mistral 7B (with many additional special tokens) aimed at generating answers up to a sequence length of 2048.

At first, I went through the normal MLC flow as described in the documentation; note that I explicitly did not apply quantization:

mlc_chat convert_weight <hf_model_path> --quantization q0f16 \
    -o <output_path>

mlc_chat gen_config <hf_model_path> --quantization q0f16 \
    --conv-template mistral_default -o <output_path>

mlc_chat compile <output_path>/mlc-chat-config.json \
    --device cuda -o <output_path>.so

When testing out the new compiled model with:

config = ChatConfig(max_batch_size=1, max_gen_len=500, temperature=0.0, top_p=0.0)  # , conv_config=conv_config)
cm = ChatModule(
    model=...,
    model_lib_path=...,
    chat_config=config,
)
# Generate a response for a given prompt
output = cm.generate(
    prompt=TEST_PROMPT,
    progress_callback=StreamToStdout(callback_interval=2),
)

the model just repeated the same token over and over, indefinitely.

After reading https://github.com/mlc-ai/mlc-llm/issues/978 and https://github.com/mlc-ai/mlc-llm/issues/802, I wanted to try the build.py script to compile the model with CUTLASS disabled, to see whether that would resolve the infinitely repeating token. However, I then ran into a TVM error during the build. Any help would be appreciated, thanks!

To Reproduce

python build.py --model /data/ML_Workdir/models/mistral7b_1e-5_warmup_100/checkpoint-13129/ --quantization q0f16 --artifact-path /data/dist/mistral7b_1e-5_warmup_100-q0f16-MLC --max-seq-len 2048 --target cuda --no-cutlass-norm --no-cutlass-attn --use-safetensors --build-model-only
Using path "/data/ML_Workdir/models/mistral7b_1e-5_warmup_100/checkpoint-13129" for model "checkpoint-13129"
Target configured: cuda -keys=cuda,gpu -arch=sm_80 -max_num_threads=1024 -max_shared_memory_per_block=49152 -max_threads_per_block=1024 -registers_per_block=65536 -thread_warp_size=32
Traceback (most recent call last):
  File "/data/mlc-llm/mlc_llm/build.py", line 47, in <module>
    main()
  File "/data/mlc-llm/mlc_llm/build.py", line 43, in main
    core.build_model_from_args(parsed_args)
  File "/data/mlc-llm/mlc_llm/core.py", line 859, in build_model_from_args
    mod, param_manager, params, model_config = model_generators[args.model_category].get_model(
                                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/mlc-llm/mlc_llm/relax_model/mistral.py", line 1016, in get_model
    create_encoding_func(bb, param_manager, config, args.quantization, sep_embed)
  File "/data/mlc-llm/mlc_llm/relax_model/mistral.py", line 866, in create_encoding_func
    logits, key_value_cache = model(
                              ^^^^^^
  File "/data/conda/envs/mlc-chat-venv/lib/python3.11/site-packages/tvm/relax/testing/nn.py", line 263, in __call__
    return self.forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/conda/envs/mlc-chat-venv/lib/python3.11/site-packages/tvm/relax/frontend/nn/subroutine.py", line 87, in new_forward
    return old_forward(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/mlc-llm/mlc_llm/relax_model/mistral.py", line 766, in forward
    hidden_states, key_value_cache = self.model(
                                     ^^^^^^^^^^^
  File "/data/conda/envs/mlc-chat-venv/lib/python3.11/site-packages/tvm/relax/testing/nn.py", line 263, in __call__
    return self.forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/conda/envs/mlc-chat-venv/lib/python3.11/site-packages/tvm/relax/frontend/nn/subroutine.py", line 87, in new_forward
    return old_forward(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/mlc-llm/mlc_llm/relax_model/mistral.py", line 720, in forward
    hidden_states, key_value_cache = decoder_layer(
                                     ^^^^^^^^^^^^^^
  File "/data/conda/envs/mlc-chat-venv/lib/python3.11/site-packages/tvm/relax/testing/nn.py", line 263, in __call__
    return self.forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/conda/envs/mlc-chat-venv/lib/python3.11/site-packages/tvm/relax/frontend/nn/subroutine.py", line 87, in new_forward
    return old_forward(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/mlc-llm/mlc_llm/relax_model/mistral.py", line 581, in forward
    hidden_states, present_key_value = self.self_attn(
                                       ^^^^^^^^^^^^^^^
  File "/data/conda/envs/mlc-chat-venv/lib/python3.11/site-packages/tvm/relax/testing/nn.py", line 263, in __call__
    return self.forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/conda/envs/mlc-chat-venv/lib/python3.11/site-packages/tvm/relax/frontend/nn/subroutine.py", line 87, in new_forward
    return old_forward(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/mlc-llm/mlc_llm/relax_model/mistral.py", line 489, in forward
    key, value, updated_key_value = self.interleave_kv(
                                    ^^^^^^^^^^^^^^^^^^^
  File "/data/mlc-llm/mlc_llm/relax_model/mistral.py", line 346, in interleave_kv
    relax.call_pure_packed(
  File "/data/conda/envs/mlc-chat-venv/lib/python3.11/site-packages/tvm/relax/utils.py", line 173, in wrapper
    bound = sig.bind(*args, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/conda/envs/mlc-chat-venv/lib/python3.11/inspect.py", line 3212, in bind
    return self._bind(args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/conda/envs/mlc-chat-venv/lib/python3.11/inspect.py", line 3201, in _bind
    raise TypeError(
TypeError: got an unexpected keyword argument 'args'
[00:12:53] /workspace/tvm/src/relax/ir/block_builder.cc:65: Warning: BlockBuilder destroyed with remaining blocks!
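For what it's worth, the TypeError at the bottom of the traceback is raised by Python's own `inspect.Signature.bind`, which the wrapper in tvm/relax/utils.py uses to validate arguments: the call site passed a keyword named `args` that the wrapped function's signature does not accept. A minimal stdlib reproduction of that failure mode (the stand-in signature below is illustrative, not TVM's actual one):

```python
import inspect

# Illustrative stand-in for the wrapped function; the real signature
# lives in tvm/relax/utils.py around the relax.call_pure_packed wrapper.
def call_pure_packed(func, *call_args, sinfo_args=None):
    return None

sig = inspect.signature(call_pure_packed)

# Binding positionally succeeds.
sig.bind("vm.builtin.some_packed_func", 1, 2, sinfo_args=None)

# Passing an unexpected keyword reproduces the traceback's TypeError.
err = None
try:
    sig.bind("vm.builtin.some_packed_func", args=(1, 2))
except TypeError as e:
    err = str(e)

print(err)
```

This matches the final frames of the traceback (`sig.bind` in tvm/relax/utils.py delegating to `inspect._bind`), which is why the error surfaces as a plain TypeError rather than a TVM diagnostic.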


MasterJH5574 commented 6 months ago

Hey @kkoehncke, sorry for the late reply. I believe that issue has since been fixed. Could you update the pip packages and try again?
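For reference, updating to the latest nightly wheels looked roughly like this at the time (the package names and CUDA suffix are assumptions; check the MLC install page for the exact command for your platform):

```shell
# Upgrade the MLC nightly wheels (CUDA 12.1 suffix shown; adjust to
# your CUDA version, or drop the suffix for CPU-only builds).
python -m pip install --pre -U -f https://mlc.ai/wheels \
    mlc-chat-nightly-cu121 mlc-ai-nightly-cu121
```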

kkoehncke commented 6 months ago

Yes, it's working now, thanks a lot! Will close.