huggingface / text-generation-inference

Large Language Model Text Generation Inference
http://hf.co/docs/text-generation-inference
Apache License 2.0

Tree-attention for medusa #2073

Open · hustxiayang opened 2 weeks ago

hustxiayang commented 2 weeks ago

Feature request

Hi, in Medusa's paper they adopt tree attention and use typical sampling to increase the speedups, but in the current code base I think it only uses argmax() and no tree attention. Would you add this feature?
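For concreteness, here is a minimal sketch of the difference I mean (this is not TGI code; the tensor shape and function names are my own assumptions for illustration):

```python
import torch

def greedy_candidates(medusa_logits: torch.Tensor) -> torch.Tensor:
    """Argmax-style decoding (what I believe TGI does today): each
    Medusa head contributes exactly its argmax token, giving a single
    linear candidate continuation.

    medusa_logits: [n_heads, vocab_size], one logit vector per head.
    """
    return medusa_logits.argmax(dim=-1)  # [n_heads]

def tree_candidates(medusa_logits: torch.Tensor, top_k: int = 3) -> torch.Tensor:
    """Tree-style decoding from the Medusa paper: each head keeps its
    top-k tokens, so the candidates form a tree with up to
    top_k ** n_heads root-to-leaf paths, all verified in one forward
    pass using a tree-attention mask.
    """
    return medusa_logits.topk(top_k, dim=-1).indices  # [n_heads, top_k]
```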

Motivation

I benchmarked some pre-trained Medusa models using TGI, and the performance is not ideal.

Your contribution

I'm glad to contribute to this feature if you think it's a good idea.

LysandreJik commented 2 weeks ago

Hey @hustxiayang, IIRC applying tree attention requires significantly more compute than using argmax, so the throughput cost is too high to justify the value.

Don't quote me on this, but IIRC without the tree the speedup is ~2x, while with the tree there is a maximum of ~3x speedup, but throughput decreases significantly (divided by 10-20x).
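To make the compute trade-off concrete, here is a minimal sketch of a tree-attention mask, assuming a flat `parent`-array encoding of the candidate tree (illustrative only, not TGI's implementation). Each tree node attends to itself and its ancestors, and every extra head or larger top-k multiplies the number of token positions verified per forward pass:

```python
import torch

def tree_attention_mask(parent: list[int]) -> torch.Tensor:
    """Build a boolean attention mask for a candidate tree.

    parent[i] is the parent index of node i, with -1 for roots;
    node i may attend to itself and every ancestor, nothing else.
    """
    n = len(parent)
    mask = torch.zeros(n, n, dtype=torch.bool)
    for i in range(n):
        j = i
        while j != -1:
            mask[i, j] = True
            j = parent[j]
    return mask

# Two heads with top-2 each: 2 root nodes, each with 2 children,
# i.e. 6 token positions verified per step instead of 2 in the
# linear (argmax) case. This multiplicative growth is where the
# extra compute comes from.
parent = [-1, -1, 0, 0, 1, 1]
print(tree_attention_mask(parent).int())
```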

This is to be confirmed with @Narsil eventually; I hope it sheds some light on the question in the meantime.

hustxiayang commented 2 weeks ago

Hi, I suppose you observed these speedups on one of the following models:

- text-generation-inference/gemma-7b-it-medusa
- text-generation-inference/Mixtral-8x7B-Instruct-v0.1-medusa
- text-generation-inference/Mistral-7B-Instruct-v0.2-medusa

If that is the case, which dataset(s) did you use?

Another question: do you plan to support the typical sampling scheme mentioned in their paper?
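For reference, here is a minimal sketch of the typical acceptance rule as described in the Medusa paper: a candidate token is accepted if its probability under the original model exceeds a threshold that shrinks when the model's next-token distribution has high entropy. The `epsilon`/`delta` defaults below are placeholders, not values from the paper or from TGI:

```python
import torch

def accept_token(probs: torch.Tensor, token: int,
                 epsilon: float = 0.3, delta: float = 0.09) -> bool:
    """Typical acceptance: accept `token` if
    probs[token] > min(epsilon, delta * exp(-H(probs))),
    where H is the entropy of the original model's next-token
    distribution `probs` ([vocab_size], summing to 1).
    """
    entropy = -(probs * probs.clamp_min(1e-12).log()).sum()
    threshold = min(epsilon, delta * torch.exp(-entropy).item())
    return probs[token].item() > threshold

# Usage: accept or reject the original model's own argmax token.
probs = torch.softmax(torch.randn(32000), dim=-1)
print(accept_token(probs, token=int(probs.argmax())))
```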