NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

How to use Medusa to support non-LLaMA models? #1946

Open skyCreateXian opened 2 months ago

skyCreateXian commented 2 months ago

System Info

Hardware: L20
Version: 0.11.0.dev20240625
Model: Bloom-7b1

Who can help?

@ncomly-nvidia @byshiue I have obtained a Medusa head for Bloom by following the official Medusa documentation, but for deployment I need to modify bloom/model.py. I wrote a version using llama/model.py as a reference, but the accuracy is very poor. I therefore have two questions:

  1. Does Medusa support deploying models other than the LLaMA family?
  2. For other models' model.py files, could you provide official Medusa modification hints, like the '[MODIFIED]' markers in the reference https://github.com/FasterDecoding/Medusa/blob/main/medusa/model/modeling_llama_kv.py ? I mainly adapted the spec_decoding_params parameter in bloom/model.py (a rough sketch of that pattern follows below).
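
For context, here is a minimal, self-contained sketch of what "adapting the spec_decoding_params parameter" amounts to, modeled on the pattern in llama/model.py: the model's forward accepts an extra speculative-decoding argument and threads it unchanged down to every decoder layer's attention call. All class and field names below are placeholders, not the real TensorRT-LLM classes, and the exact signatures in bloom/model.py may differ by version.

```python
from dataclasses import dataclass
from typing import Optional, Sequence

# Placeholder container: in TensorRT-LLM the real object carries the Medusa
# position offsets, packed tree-attention mask, and generation lengths.
@dataclass
class SpecDecodingParamsSketch:
    position_offsets: Optional[Sequence[int]] = None
    packed_mask: Optional[Sequence[int]] = None
    generation_lengths: Optional[Sequence[int]] = None

class AttentionSketch:
    def forward(self, hidden_states, spec_decoding_params=None):
        # A real layer would hand these tensors to the fused attention plugin;
        # the point here is only that the argument must reach this call.
        if spec_decoding_params is not None:
            _ = spec_decoding_params.position_offsets
        return hidden_states

class DecoderLayerSketch:
    def __init__(self):
        self.attention = AttentionSketch()

    def forward(self, hidden_states, spec_decoding_params=None):
        return self.attention.forward(
            hidden_states, spec_decoding_params=spec_decoding_params)

class BloomModelSketch:
    def __init__(self, num_layers=2):
        self.layers = [DecoderLayerSketch() for _ in range(num_layers)]

    def forward(self, hidden_states, spec_decoding_params=None):
        # The key change vs. the stock model: the new keyword is accepted here
        # and forwarded to every layer instead of being dropped.
        for layer in self.layers:
            hidden_states = layer.forward(
                hidden_states, spec_decoding_params=spec_decoding_params)
        return hidden_states
```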

Information

Tasks

Reproduction

1. Train a Medusa head for the Bloom model
2. Adapt the spec_decoding_params parameter in bloom/model.py

Expected behavior

nothing

actual behavior

nothing

additional notes

nothing

skyCreateXian commented 2 months ago

Is GatedMLP suitable for Medusa decoding? I found two characteristics during debugging:

  1. The only difference between my modified bloom/model.py and llama/model.py lies in the MLP layer: llama uses GatedMLP, while Bloom uses a plain MLP (a minimal sketch of the difference follows below).
  2. When the accepted tokens come from the Medusa result, the last accepted token is never aligned. Is an MLP layer unsuitable for the Medusa algorithm?
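
For reference, the structural difference in point 1, written as a minimal NumPy sketch rather than the actual TensorRT-LLM layers: a plain MLP is down(act(up(x))), while a GatedMLP is down(act(gate(x)) * up(x)). Neither variant touches token positions or the KV cache, which is why the MLP type by itself should not affect Medusa acceptance.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU, for illustration only
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def silu(x):
    return x / (1.0 + np.exp(-x))

def mlp(x, w_up, w_down):
    # Bloom-style MLP: down( act( up(x) ) )
    return gelu(x @ w_up) @ w_down

def gated_mlp(x, w_gate, w_up, w_down):
    # LLaMA-style GatedMLP: down( act( gate(x) ) * up(x) )
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down
```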

rakib-hasan commented 2 months ago

Hi @skyCreateXian, thank you for bringing this up. Agreed, we should have documentation on the steps required to make Medusa work with other models. I think you are on the right track. The following steps should be enough to support Medusa for other models:

  1. Adding spec_decoding_params to the base model (e.g. Bloom in this case).
  2. A new conversion script to combine the base model and the Medusa heads into a TensorRT-LLM checkpoint (a rough sketch of this step follows the list).
  3. Changing medusa/model.py to use the updated base model.
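
As a rough illustration of step 2, assuming the checkpoint is a safetensors file plus a config.json (the key names such as "medusa_heads.*" and the config fields below are hypothetical placeholders, not a documented format): load the converted base-model weights, merge in the trained Medusa head weights under a distinct prefix, and record the Medusa geometry in the config so the build step knows how many draft heads exist.

```python
import json
from safetensors.numpy import load_file, save_file

def combine_checkpoints(base_dir, medusa_heads_path, out_dir,
                        num_medusa_heads=4, num_medusa_layers=1):
    # Base-model weights already converted to a TensorRT-LLM checkpoint.
    base_weights = load_file(f"{base_dir}/rank0.safetensors")
    # Trained Medusa head weights (file layout here is an assumption).
    medusa_weights = load_file(medusa_heads_path)

    combined = dict(base_weights)
    for name, tensor in medusa_weights.items():
        # Keep head weights under their own prefix so they cannot collide
        # with base-model weight names.
        combined[f"medusa_heads.{name}"] = tensor

    with open(f"{base_dir}/config.json") as f:
        config = json.load(f)
    # Hypothetical config fields describing the Medusa geometry.
    config["num_medusa_heads"] = num_medusa_heads
    config["num_medusa_layers"] = num_medusa_layers

    save_file(combined, f"{out_dir}/rank0.safetensors")
    with open(f"{out_dir}/config.json", "w") as f:
        json.dump(config, f, indent=2)
```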

To answer your question on MLP: it shouldn't have any effect on Medusa. One other difference I can think of that can lead to poor accuracy is the position embedding: RoPE vs. ALiBi. With Medusa, a position-offset tensor is passed to the model to properly apply the position embedding to the Medusa tokens. I am not too familiar with ALiBi yet, but if it requires more than just the position offsets, then that could be the other thing needed to support Medusa with Bloom.
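
To make the RoPE-vs-ALiBi distinction concrete, a short self-contained sketch using the standard definitions (not TensorRT-LLM code): RoPE injects position by rotating query/key vectors according to each token's absolute position index, so a corrected per-token position-offset tensor is all the Medusa tree tokens need; ALiBi instead adds a distance-proportional bias to the attention logits inside the kernel, so the attention kernel itself has to understand the tree layout.

```python
import numpy as np

def rope_rotate(x, position, base=10000.0):
    # Rotate pairs of dimensions of a query/key vector by angles that depend
    # on the token's absolute position index; with Medusa it is enough to feed
    # the right position index for each tree token.
    half = x.shape[-1] // 2
    freqs = base ** (-np.arange(half) / half)
    angles = position * freqs
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

def alibi_bias(query_pos, key_pos, slope):
    # ALiBi adds a per-head linear penalty on query/key distance directly to
    # the attention scores, so tree attention must apply it per tree branch
    # inside the kernel rather than via a position-offset input.
    return -slope * np.abs(query_pos - key_pos)
```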

I hope this helps. Please let us know how it goes and/or if you have any more questions.

skyCreateXian commented 2 months ago

@rakib-hasan How can I verify the differences caused by the position-encoding algorithm? I found that forcibly setting position_embedding_type to "rope_gpt_neox" in convert_checkpoint did not work.

sundayKK commented 2 months ago

Hello, how can Medusa be used to support the Qwen model? It is different from both LLaMA and Bloom.

skyCreateXian commented 1 month ago

@sundayKK I adapted Qwen2-7B, but found that the result was completely different from the base model, so it failed (a sanity-check sketch follows the list below). You can follow the steps below:

  1. Adapt Qwen training in Medusa to obtain the trained heads
  2. Modify models/medusa/model.py to support Qwen
  3. Modify models/qwen/model.py to support the speculative-decoding parameters
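
One way to narrow this kind of failure down: with greedy sampling (temperature 0), Medusa speculative decoding should reproduce the base model's output token-for-token, so comparing the two sequences shows where they first diverge. The helper below is generic; the functions that actually run the base engine and the Medusa engine are left as hypothetical placeholders in the usage comment.

```python
from typing import Optional, Sequence

def first_divergence(base_tokens: Sequence[int],
                     medusa_tokens: Sequence[int]) -> Optional[int]:
    """Index of the first position where the two greedy outputs differ
    (a length mismatch counts as divergence at the shorter length),
    or None if the sequences are identical."""
    for i, (a, b) in enumerate(zip(base_tokens, medusa_tokens)):
        if a != b:
            return i
    if len(base_tokens) != len(medusa_tokens):
        return min(len(base_tokens), len(medusa_tokens))
    return None

# Usage (hypothetical helpers standing in for runs of the two engines):
#   base = generate_with_base_engine(prompt, max_new_tokens=128)     # greedy
#   spec = generate_with_medusa_engine(prompt, max_new_tokens=128)   # greedy
#   print(first_divergence(base, spec))
```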

rakib-hasan commented 1 month ago

@skyCreateXian Apologies for the late response. That sounds correct. Changing the position encoding at inference time won't work, since the Bloom model appears to have been trained with ALiBi. The problem is that, as I understand it, the XQA kernel supports tree attention (required by Medusa) but doesn't support ALiBi. So, at this point, Medusa won't work with models that use ALiBi.

@sundayKK It seems Qwen2 uses RoPE, so it should be compatible. I do not know that architecture's details yet, but are there any other differences between Qwen2 and LLaMA?

sundayKK commented 1 month ago

@skyCreateXian @rakib-hasan Thanks for your answers! I'd like to try.

skyCreateXian commented 1 month ago

@rakib-hasan I tested Qwen2-7B and found that it cannot be aligned on this model either, so I suspect the diff is not caused by positional-encoding differences. I will continue to check.