tagging @gante since you have recently worked on lookahead decoding
Hi @apoorvumang 👋
First of all, thank you for creating this clever strategy and for sharing it openly! It's simple and elegant, which makes it really great.
I've been thinking about it from a software point of view. The core functionality (`find_candidate_pred_tokens`) is simple, and it reuses the core of `assisted_generation`. I'm also seeing more techniques that speed up LLMs through the generation of candidate sequences. As such, here's my proposal: refactor `assisted_generation` into a generalist decoding technique that accepts an arbitrary function to generate candidates. `assisted_generation` would be a variant of this function, as is your technique. Then:
a. Your technique's parameters get added to the `GenerationConfig` class, defaulting to `None`
b. When the parameters above are non-`None`, your technique would get triggered in `generate`, using the same pattern
c. After the code is in a nearly ready state, we'll do some benchmarks over different tasks and share them on social media

Does it sound good to you? 🤗
(LMK if you'd like further pointers)
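For readers who haven't opened the linked repo, here is a minimal sketch of the kind of n-gram lookup `find_candidate_pred_tokens` performs; the names, defaults, and shapes are illustrative, and the reference implementation is in the repo linked in the feature request below:

```python
import torch

def find_candidate_pred_tokens(input_ids, max_matching_ngram_size=3, num_pred_tokens=10):
    """Propose candidate continuations by matching the trailing n-gram of the
    current sequence against earlier positions in the prompt."""
    seq_len = input_ids.size(1)
    # Try the longest n-gram first, falling back to shorter ones
    for ngram_size in range(max_matching_ngram_size, 0, -1):
        if seq_len <= ngram_size:
            continue
        ngram = input_ids[0, -ngram_size:].tolist()
        # Scan earlier positions for the same n-gram (most recent occurrence first)
        for start in range(seq_len - ngram_size - 1, -1, -1):
            if input_ids[0, start:start + ngram_size].tolist() == ngram:
                end = start + ngram_size
                # The tokens that followed the earlier occurrence become the candidates
                return input_ids[0, end:end + num_pred_tokens]
    # No match: no candidates this step, fall back to normal decoding
    return torch.empty(0, dtype=torch.long, device=input_ids.device)
```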
Sounds good! I will try to read up and implement it.
Agree on the `assisted_generation` refactoring as well - maybe we could even have a user-provided `assistant_function`? (but that's a software decision I'm not qualified to make)
@apoorvumang #27750 has the new `assisted_decoding` structure. It is still subject to review, but you can now have a concrete idea of what I had in mind :)
Adding your technique on top of it should be straightforward!
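For readers who haven't opened #27750, the rough shape of the idea is a candidate-generator object that the decoding loop queries at each step. A simplified illustration (not the PR's exact API), reusing the `find_candidate_pred_tokens` sketch above, could look like this:

```python
import torch

class CandidateGenerator:
    """Anything that can propose candidate tokens for the main model to verify."""

    def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor:
        raise NotImplementedError

    def update_candidate_strategy(self, input_ids: torch.LongTensor, num_matches: int) -> None:
        # Optional hook: adapt hyperparameters based on how many candidate
        # tokens the main model accepted in the last verification step
        pass


class PromptLookupCandidateGenerator(CandidateGenerator):
    """Prompt lookup decoding: candidates come from n-gram matches in the prompt."""

    def __init__(self, num_pred_tokens: int = 10, max_matching_ngram_size: int = 3):
        self.num_pred_tokens = num_pred_tokens
        self.max_matching_ngram_size = max_matching_ngram_size

    def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor:
        # Reuses the find_candidate_pred_tokens sketch from earlier in the thread
        return find_candidate_pred_tokens(
            input_ids, self.max_matching_ngram_size, self.num_pred_tokens
        )
```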
Thanks @gante ! Will look into the refactored code now. I think I should be able to get something running by tonight (IST)
I have a working implementation here, based on #27750: https://github.com/apoorvumang/transformers/tree/prompt_lookup_decoding . Should I start a PR with it?
Also, if you suggest any benchmarks/benchmarking code, I can help with that. I have access to A100 40GB GPU and M1 Max 32GB @gante
@apoorvumang Yes, open a PR! I can add a few suggestions even before #27750 is merged :)
My advice for benchmarks would be the following: users love it when a certain method works well with little to no hyperparameters. At the moment, I see two hyperparameters -- `prompt_lookup_num_tokens` and `prompt_lookup_max_matching_ngram`. I'd run a few benchmarks over a few datasets, changing these hyperparameters, to find whether we can:
a) get away with only one hyperparameter, OR
b) set an update heuristic that gets the best hyperparameters for the input at hand (through the `update_candidate_strategy` method) -- one possible shape for such a heuristic is sketched below

If you find a way to make a) or b) work, the technique would become more user-friendly, and thus have a higher chance of being correctly used. For us, `transformers` maintainers, having fewer flags is also great!
After we settle on a final implementation, I can validate the benchmarks on different devices (e.g. a T4, a 3090, ...). Given the simplicity of the technique, I suspect the results will be mostly hardware agnostic on GPU :)
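One possible shape for option b), building on the candidate-generator sketch above: grow or shrink the lookahead based on how many candidates were accepted in the last step. This is purely illustrative and not benchmarked; the specific thresholds and bounds are placeholders.

```python
class AdaptivePromptLookupCandidateGenerator(PromptLookupCandidateGenerator):
    """Illustrative only: lengthen the candidate window while candidates keep
    getting accepted, shorten it when they are mostly rejected."""

    def update_candidate_strategy(self, input_ids, num_matches: int) -> None:
        if num_matches >= self.num_pred_tokens - 1:
            # Nearly everything was accepted: propose a longer continuation next time
            self.num_pred_tokens = min(self.num_pred_tokens + 2, 32)
        elif num_matches == 0:
            # Nothing was accepted: propose fewer tokens to keep verification cheap
            self.num_pred_tokens = max(self.num_pred_tokens - 2, 2)
```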
Started PR here: https://github.com/huggingface/transformers/pull/27775/commits . Please do leave suggestions @gante
I will start some benchmarking on my side to find optimal hyperparameters (or update schedules). Maybe both of these can be tuned best with just a default value + update schedule, and if users really want to change the defaults they can instantiate and provide a `PromptLookupCandidateGenerator` with new params.
Will get back once I start some tests. I will try some standard summarization and QA datasets, and maybe look for a code-editing sort of dataset.
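Assuming the `GenerationConfig` flags land as discussed above, the user-facing call could look roughly like this (illustrative; the exact flag names depend on the final PR):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "The quick brown fox jumps over the lazy dog. The quick brown"
inputs = tokenizer(prompt, return_tensors="pt")

# Passing prompt_lookup_num_tokens would switch generate() to prompt lookup
# decoding; leaving it unset (None) keeps the regular decoding path.
outputs = model.generate(**inputs, max_new_tokens=40, prompt_lookup_num_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```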
There is a significant difference between greedy and sampling when summarizing, but there are still gains. A proper analysis of the phenomenon would probably be a paper-worthy effort.
I will try to run a similar thing for code editing as well. If you think there's something I could try, pls let me know.
One question @gante: is the most popular method greedy or sampling? (I would assume greedy since it's the default, but I know sampling is better for quality.) If I could optimize for only one of these, which one should be the 'default'?
> If I could optimize for only one of these, which one should be the 'default'?

Naive question/input here... but assuming you can figure out the optimisations, and they don't apply equally to both, would it be possible to have 2 settings for it? One when used with greedy and one when used with sampling? Even if that's handled automagically under the hood (or even, presumably, if it's exposed to users, it would be simpler than having to know the exact hyperparameters to tune?)
Thanks! Yes, it can ofc -- `_get_candidate_generator` has access to `generation_config`, which can be passed on here to check for stuff like this.
Any other thoughts/ideas @0xdevalias ?
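For illustration, the branching could be as simple as the following hypothetical helper; `do_sample` is a standard `GenerationConfig` field, but the default numbers below are placeholders, not tuned values:

```python
def pick_prompt_lookup_defaults(generation_config):
    # Hypothetical helper: choose prompt-lookup defaults based on the decoding
    # mode the user requested. The numbers are illustrative only.
    if generation_config.do_sample:
        return {"num_pred_tokens": 5, "max_matching_ngram_size": 2}
    return {"num_pred_tokens": 10, "max_matching_ngram_size": 3}
```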
> Thanks! Yes, it can ofc
@apoorvumang Awesome :)
> Any other thoughts/ideas?
@apoorvumang None at this stage; was more of a 'drive by random brain spark' type moment :)
@apoorvumang @0xdevalias the preliminary results seem to point out that there is no obvious parameterization 🤔 Let's wait to see the results for coding!
Regarding sampling vs greedy: greedy is the default for legacy reasons, but sampling is by far the most popular with chat LLMs :) Tasks like summarization, translation, or automatic speech recognition tend to use greedy decoding or beam search, though.
Finally, regarding default values: we'll have to default the values to `None`, so we can detect whether the user wants to use it or not. We have a few default values for legacy reasons, but the defaults should be set at a model level (with the `generation_config.json`). This does not prevent us, however, from suggesting values in the parameters' docstring 🤗
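For reference, a model-level default of that kind would end up in `generation_config.json` roughly like this (sketch; assumes the parameter name survives review):

```python
from transformers import GenerationConfig

generation_config = GenerationConfig.from_pretrained("gpt2")
# Hypothetical model-level default; generate() would pick it up unless the
# caller overrides it explicitly.
generation_config.prompt_lookup_num_tokens = 10
generation_config.save_pretrained("./my-model")  # writes generation_config.json
```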
Here are results using MT-Bench, 2nd turn of the coding questions only:
All 80 samples from mt-bench, 2nd turn only.
Hi @apoorvumang – Thanks for sharing your great work!
Two quick questions:
@keyboardAnt the error bars are usually the standard deviation of the measurement, which is a centered (and symmetric) moment -- it does not denote the minimum/maximum of a measurement, nor a range between percentiles.
As such, I'm reading it as a long-tailed distribution. Some speedups are huge (e.g. 5x), while most are moderate (e.g. 1.5x)
Hi @keyboardAnt, thank you!
Here's an example of this phenomenon, courtesy of ChatGPT:
PS: Sorry for the delay in working on this PR - I will try to work on it this weekend
@gante, @apoorvumang, yes. Because of the high variance, we should consider the minimal tokens/sec rate. This could ensure the long tail is one-sided. Otherwise, it might suggest a slowdown.
@keyboardAnt Could you please expand on what you mean? Like we should look for configs with a good lower bound for tokens/sec rather than a good average?
@apoorvumang, my suggestion is to measure `speedup`. That is,

speedup := (tokens per second with PLD) / (tokens per second without PLD)

where the with-PLD and without-PLD runs share the same variables (e.g., prompt, target model, GPU device). We want to show that `speedup >> 1` in most cases, and to rule out the possibility that `speedup < 1` (i.e., a slowdown). The visualizations you shared do not rule out the possibility that `speedup < 1`.

We must measure `speedup` in varied configurations so we can better understand it. Each configuration has a unique prompt, target model, or `(max_matching_ngram, num_token_output)` hyperparameter. Visualizing the distribution of `speedup` and calculating its harmonic mean can help.
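A minimal harness along those lines might look like this (sketch; the `pld_kwargs`/`baseline_kwargs` dictionaries are placeholders for whatever flags the final PR exposes):

```python
import time
from statistics import harmonic_mean

def tokens_per_second(model, inputs, **generate_kwargs):
    # Time one generate() call and return the new-tokens-per-second rate
    start = time.perf_counter()
    output = model.generate(**inputs, **generate_kwargs)
    elapsed = time.perf_counter() - start
    new_tokens = output.shape[1] - inputs["input_ids"].shape[1]
    return new_tokens / elapsed

def measure_speedup(model, inputs, pld_kwargs, baseline_kwargs):
    # speedup := (tokens/sec with PLD) / (tokens/sec without PLD), with the same
    # prompt, target model, and device for both runs
    with_pld = tokens_per_second(model, inputs, **pld_kwargs)
    without_pld = tokens_per_second(model, inputs, **baseline_kwargs)
    return with_pld / without_pld

# Aggregate over many configurations (prompt, model, hyperparameters), e.g.:
# speedups = [measure_speedup(model, inp, pld_kwargs, base_kwargs) for inp in prompts]
# print(harmonic_mean(speedups))
```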
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
> @apoorvumang, my suggestion is to measure `speedup` [...] We must measure `speedup` in varied configurations so we can better understand it.
We recently released this preprint, which also covers the question of slowdowns: https://arxiv.org/pdf/2405.14105
Our experiments show that slowdowns exist in practice (for example, if PLD is too slow or inaccurate). We also propose a novel algorithm for running PLD (or any other drafters) in parallel so that there are no slowdowns.
Feature request
Prompt lookup decoding: a recently proposed method that replaces the draft model of speculative/assisted decoding with simple string matching in the prompt.
Code: https://github.com/apoorvumang/prompt-lookup-decoding
Motivation
Your contribution
I have a not-so-well-written implementation here (Python notebook). I can contribute to making it better, but will need help since it's my first time.