alipay / PainlessInferenceAcceleration


In the benchmark studies, how are the draft tokens generated? #9

Open jivanph opened 8 months ago

jivanph commented 8 months ago

I read with great interest your paper 'Lookahead: An Inference Acceleration Framework for Large Language Model with Lossless Generation Accuracy'.

In essence, the paper proposes a tree data structure to verify proposed draft tokens, and in this way speed up inference.

Unfortunately, it's not clear to me from the paper how these draft tokens were generated when establishing benchmark results for LookAhead-Parallel and LookAhead-Hierarchical.

I understand the focus of the paper is on how to handle a set of draft tokens (perhaps as a single branch, perhaps in parallel, or perhaps in a hierarchical manner), but the origin of the draft tokens in the benchmark results remains unclear to me.

jivanph commented 8 months ago

A related question regarding the benchmark studies: what sampling mechanism was used to accept tokens? Was it greedy sampling?

zheyishine commented 8 months ago

To Q1: The draft tokens are generated from a cached trie tree (each node is a token id). Currently the trie tree is constructed from prompts and responses on-the-fly, so it is friendly for deployment (neither an additional assist model nor head training is needed), and it works quite well in real-life scenarios. In addition, we have also probed Jacobi iteration to yield additional drafts and will integrate it into the repo soon (even though its speedup is marginal).

To Q2: Yes, we use the greedy strategy in our benchmarks.
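
For intuition, here is a minimal sketch of the kind of token-id trie described above: token sequences seen in prompts and responses are inserted, and cached continuations of a given prefix are returned as draft branches. The class and method names are illustrative assumptions, not the repo's actual API.

```python
# Minimal sketch (not the repo's actual implementation): a token-id trie that
# caches previously seen token sequences and returns draft continuations for
# a given prefix.

class TokenTrieNode:
    def __init__(self):
        self.children = {}  # token_id -> TokenTrieNode


class TokenTrie:
    def __init__(self):
        self.root = TokenTrieNode()

    def insert(self, token_ids):
        """Cache a token sequence (e.g., a prompt or a generated response)."""
        node = self.root
        for tok in token_ids:
            node = node.children.setdefault(tok, TokenTrieNode())

    def get_drafts(self, prefix, max_depth=8):
        """Return cached continuations of `prefix` as draft branches."""
        node = self.root
        for tok in prefix:
            if tok not in node.children:
                return []
            node = node.children[tok]

        branches = []

        def walk(n, path):
            # Stop at leaves or at the depth limit and record the branch.
            if not n.children or len(path) >= max_depth:
                if path:
                    branches.append(path)
                return
            for tok, child in n.children.items():
                walk(child, path + [tok])

        walk(node, [])
        return branches
```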

jivanph commented 8 months ago

Thank you for your responses. I understand the tokens are passed as a tree object. But my question is, how did you choose which tokens to use in the benchmark? What do you mean by "responses on-the-fly"?

zheyishine commented 8 months ago

We choose tokens not only from responses, but also from prompts.

jivanph commented 8 months ago

Could you explain a little bit further how you choose the tokens?

zheyishine commented 8 months ago

In the benchmark, we first generate responses for samples from the dev set and put the responses into a global trie tree; then we evaluate each prompt in the test set (all test samples are different from the dev-set samples). For each query in the test set, we first put the query into the global trie tree, and the generated tokens are also put into the trie tree on-the-fly. The tokens in the global trie tree then have a chance to be chosen as drafts for the following queries.
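
A hedged sketch of the benchmark flow described above, reusing the illustrative TokenTrie from the earlier snippet; `generate_with_lookahead`, `dev_prompts`, and `test_prompts` are hypothetical placeholders, not names from this repo.

```python
trie = TokenTrie()

# 1) Warm-up: generate responses for dev-set samples and cache them in the
#    global trie.
for prompt_ids in dev_prompts:
    response_ids = generate_with_lookahead(prompt_ids, trie)
    trie.insert(response_ids)

# 2) Evaluation: for each test query, cache the prompt tokens first; the
#    generated tokens are cached as well (here after generation, in the real
#    setup on-the-fly), so later queries can reuse them as drafts.
for prompt_ids in test_prompts:
    trie.insert(prompt_ids)
    response_ids = generate_with_lookahead(prompt_ids, trie)
    trie.insert(response_ids)
```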

jivanph commented 7 months ago

Thank you for your response. Could you please point me to the part of the code that is in charge of verifying the proposed trees?

My main question is: how do we verify that the output of regular decoding with sampling (instead of greedy decoding) coincides with PAIN decoding with sampling? How can we tell that responses produced by PAIN are the same as responses obtained from regular decoding (under sampling)?

zheyishine commented 7 months ago

Lines from here to here are used for verification of tree drafts.

Our lookahead decoding cannot generate exactly the same response as the generation mode SAMPLE in transformers, due to random sampling (i.e., the randomness introduced by torch.multinomial). We guarantee the same distribution by following the decoding steps of the generation mode SAMPLE. Our implementation is aligned with the generation mode ASSISTED_GENERATION in transformers. Details can be found in the lines from here to here.
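
For intuition, here is a simplified sketch of how a single draft branch can be verified under greedy decoding (the setting used in the benchmarks above): run the model once over the prompt plus the draft, then accept the longest prefix of the draft that matches the model's argmax predictions. This assumes a Hugging Face-style causal LM and is only an illustration; the repo's actual code verifies whole tree drafts and, for sampling, aligns with ASSISTED_GENERATION as described above.

```python
import torch


@torch.no_grad()
def verify_branch(model, input_ids, draft_ids):
    """Greedy verification of one draft branch (illustrative only).

    input_ids, draft_ids: 1-D LongTensors of token ids.
    Returns the accepted draft tokens plus one token chosen by the model.
    """
    # One forward pass over the concatenated sequence (KV cache omitted here).
    seq = torch.cat([input_ids, draft_ids], dim=-1).unsqueeze(0)
    logits = model(seq).logits[0]

    # The model's greedy prediction aligned with each draft position.
    start = input_ids.shape[-1] - 1
    preds = logits[start:start + draft_ids.shape[-1]].argmax(dim=-1)

    accepted = 0
    for i, tok in enumerate(draft_ids):
        if preds[i].item() != tok.item():
            break
        accepted += 1

    if accepted < draft_ids.shape[-1]:
        # The draft diverged: use the model's own token at the mismatch.
        bonus = preds[accepted].item()
    else:
        # The whole draft matched: append one extra greedy token.
        bonus = logits[start + accepted].argmax(dim=-1).item()

    return draft_ids[:accepted].tolist() + [bonus]
```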

fuerpy commented 7 months ago

> The draft tokens are generated from a cached trie tree (each node is a token id). Currently the trie tree is constructed from prompts and responses on-the-fly, so it is friendly for deployment (neither an additional assist model nor head training is needed), and it works quite well in real-life scenarios. In addition, we have also probed Jacobi iteration to yield additional drafts and will integrate it into the repo soon (even though its speedup is marginal).

I'm also confused about how the draft tokens are created. Do you mean that these draft tokens are generated from previous prompt and response records, and not by model sampling?