apple / ml-recurrent-drafter

Apache License 2.0

How to support other models? #2

Open skyCreateXian opened 3 months ago

skyCreateXian commented 3 months ago

@federicobucchi @wangkuiyi @tuzhucheng How can models such as Baichuan, Bloom, or Qwen be supported? Does the modeling code need to be modified? Could you provide steps for training with other models?

wangkuiyi commented 1 month ago

Sorry for the late reply. Thank you @skyCreateXian for your question! We created our own modeling_llama.py instead of using the one in Hugging Face Transformers for two main reasons:

  1. At the time we developed this codebase, Hugging Face Transformers did not yet support a pre-allocated KV cache. This feature is essential for a fair comparison with other open-source speculative decoding methods, which rely on a pre-allocated KV cache.
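To make "pre-allocated" concrete, here is a toy, pure-Python sketch: storage for a fixed maximum number of positions is reserved once, new tokens are written into the next free slots, and discarding rejected draft tokens only moves a length counter. The class and method names are hypothetical and this is not the repository's implementation; a real cache holds per-layer tensors on the accelerator.

```python
from typing import List, Sequence, Tuple

Vector = List[float]


class PreallocatedKVCache:
    """Toy KV cache for a single attention head with a fixed capacity."""

    def __init__(self, max_len: int, head_dim: int) -> None:
        self.max_len = max_len
        self.head_dim = head_dim
        # All slots are allocated once, up front; decoding never grows them.
        self.keys: List[Vector] = [[0.0] * head_dim for _ in range(max_len)]
        self.values: List[Vector] = [[0.0] * head_dim for _ in range(max_len)]
        self.length = 0  # number of slots currently holding valid entries

    def append(self, new_keys: Sequence[Vector], new_values: Sequence[Vector]) -> None:
        """Write entries for new tokens into the next free slots."""
        if len(new_keys) != len(new_values):
            raise ValueError("keys and values must cover the same number of tokens")
        if self.length + len(new_keys) > self.max_len:
            raise ValueError("KV cache capacity exceeded")
        for k, v in zip(new_keys, new_values):
            if len(k) != self.head_dim or len(v) != self.head_dim:
                raise ValueError("vector size does not match head_dim")
        for k, v in zip(new_keys, new_values):
            # Copy in place so the pre-allocated slot lists are reused.
            self.keys[self.length][:] = k
            self.values[self.length][:] = v
            self.length += 1

    def truncate(self, new_length: int) -> None:
        """Roll back to new_length entries, e.g. after draft tokens are rejected."""
        if not 0 <= new_length <= self.length:
            raise ValueError("new_length out of range")
        self.length = new_length  # stale slots are simply overwritten later

    def view(self) -> Tuple[List[Vector], List[Vector]]:
        """Return the valid prefix of the cache."""
        return self.keys[: self.length], self.values[: self.length]
```

For example, after appending one prompt token and two draft tokens, `cache.truncate(2)` keeps the prompt token and the first (accepted) draft token without freeing or reallocating anything. A cache that instead concatenates new entries onto existing tensors at each step has to reallocate as it grows, which is one reason a fixed buffer suits methods that verify several tokens per step and often roll back.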

  2. Our modeling_llama.py accepts more parameters than Hugging Face's version. This is necessary because our implementation needs to verify one or more candidate token sequences generated by the beam search algorithm, which calls the recurrent draft model. We also employ a dynamic tree attention algorithm to eliminate duplicated prefixes in these candidate token sequences, leading to a varying number of tokens for each sequence.
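The prefix-deduplication idea can be illustrated with a generic prefix tree (trie): candidate sequences that share a prefix reuse the same nodes, and each node may attend only to itself and its ancestors. This is a hypothetical sketch of the general technique, not the repository's dynamic tree attention code; the function names are made up, and in a real model every node would also attend to all cached prompt positions.

```python
from typing import Dict, List, Tuple


def build_prefix_tree(candidates: List[List[int]]) -> Tuple[List[int], List[int], List[List[int]]]:
    """Merge candidate token sequences that share prefixes.

    Returns (tokens, parents, paths):
      tokens[i]  - token id of the i-th unique node
      parents[i] - index of node i's parent, or -1 if it starts a sequence
      paths[c]   - node indices that spell out candidate c
    """
    tokens: List[int] = []
    parents: List[int] = []
    children: Dict[Tuple[int, int], int] = {}  # (parent index, token) -> node index
    paths: List[List[int]] = []
    for seq in candidates:
        parent = -1
        path: List[int] = []
        for tok in seq:
            key = (parent, tok)
            if key not in children:  # first time this prefix + token is seen
                children[key] = len(tokens)
                tokens.append(tok)
                parents.append(parent)
            parent = children[key]
            path.append(parent)
        paths.append(path)
    return tokens, parents, paths


def tree_attention_mask(parents: List[int]) -> List[List[bool]]:
    """mask[i][j] is True when node j is node i itself or one of its ancestors."""
    n = len(parents)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:  # walk up to the root, marking each visible node
            mask[i][j] = True
            j = parents[j]
    return mask
```

For candidates `[[5, 7, 9], [5, 7, 2], [5, 3, 4]]`, the nine draft tokens collapse to six nodes: `tokens = [5, 7, 9, 2, 3, 4]` with `parents = [-1, 0, 1, 1, 0, 4]`. The second candidate contributes only one new node and the third contributes two, which is the varying per-sequence token count described above.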

If the model you want to try has the same architecture as LLaMA (e.g., DeepSeek or Mistral), no changes are needed. However, if you are working with a model that has a different architecture, you may want to consider the two points above. For more details, you can diff our modeling_llama.py against the Hugging Face version.
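One way to run that comparison from Python is the standard library's `difflib`, which produces a unified diff similar to `diff -u`. This is a minimal sketch; the function name is hypothetical and the paths are whatever you pass in. The installed Hugging Face file can be located via `transformers.models.llama.modeling_llama.__file__`, though its contents depend on your Transformers version.

```python
import difflib
from pathlib import Path


def diff_files(ours: str, theirs: str) -> str:
    """Return a unified diff of two source files (empty string if identical)."""
    a = Path(ours).read_text().splitlines(keepends=True)
    b = Path(theirs).read_text().splitlines(keepends=True)
    # Lines prefixed with "-" appear only in `ours`, "+" only in `theirs`.
    return "".join(difflib.unified_diff(a, b, fromfile=ours, tofile=theirs))
```

Calling `print(diff_files(ours_path, hf_path))` prints the hunks, which should make the extra parameters and the KV-cache handling stand out when porting to another architecture.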

Let me know if you have more questions. I would be glad to see community contributions that enable other models.