huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Adding support for prompt lookup decoding (variant of assisted generation) #27722

Closed: apoorvumang closed this issue 10 months ago

apoorvumang commented 11 months ago

Feature request

Prompt lookup decoding is a recently proposed method that replaces the draft model of assisted generation with simple string matching against the prompt.

Code: https://github.com/apoorvumang/prompt-lookup-decoding
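To make the idea concrete, here is a minimal sketch of the lookup step in plain Python over lists of token ids. The function name and its parameters are hypothetical (they loosely mirror the two hyperparameters discussed later in this thread), and this is not the code from the linked repository or from transformers. It assumes the earliest earlier occurrence of the trailing n-gram is used; picking the most recent occurrence would be an equally reasonable choice.

```python
from typing import List


def prompt_lookup_candidates(
    input_ids: List[int],
    max_ngram_size: int = 3,
    num_pred_tokens: int = 10,
) -> List[int]:
    """Propose draft tokens by matching the trailing n-gram against earlier context.

    Tries the longest n-gram first and falls back to shorter ones. On the
    earliest earlier occurrence of the trailing n-gram, returns up to
    `num_pred_tokens` tokens that followed it. Returns [] when nothing matches.
    """
    for ngram_size in range(max_ngram_size, 0, -1):
        if len(input_ids) <= ngram_size:
            continue  # too short for an earlier window to exist
        ngram = input_ids[-ngram_size:]
        # Only scan windows that start before the trailing n-gram itself.
        for start in range(len(input_ids) - ngram_size):
            if input_ids[start:start + ngram_size] == ngram:
                end = start + ngram_size
                return input_ids[end:end + num_pred_tokens]
    return []


# Prompt + generated tokens so far; the trailing [1, 2, 3] also appears at the start.
print(prompt_lookup_candidates([1, 2, 3, 4, 5, 1, 2, 3], max_ngram_size=3, num_pred_tokens=2))
# prints [4, 5]
```

In a real decoding loop, `input_ids` would be the prompt plus everything generated so far, and the returned tokens would be handed to the target model for verification, exactly where a draft model's output would go in assisted generation.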

Motivation

Your contribution

I have a not-so-well-written implementation here (Python notebook). I can contribute to making it better, but I will need help since it's my first time.

apoorvumang commented 11 months ago

Tagging @gante since you have recently worked on lookahead decoding.

gante commented 11 months ago

Hi @apoorvumang 👋

First of all, thank you for creating this clever strategy and for sharing it openly! It's simple and elegant, which makes it really great.

I've been thinking about it from a software point of view. The core functionality (find_candidate_pred_tokens) is simple, and it reuses the core of assisted_generation. I'm also seeing more techniques that speed up LLMs through the generation of candidate sequences. As such, here's my proposal:

  1. I'll open a PR today to refactor the contents of assisted_generation into a generalist decoding technique that accepts an arbitrary function to generate candidates. assisted_generation would be a variant of this function, as is your technique.
  2. In parallel, you can work on adding your technique on top of the generalist decoding technique with candidates:
     a. You'll have to define the controlling parameters of your technique in the GenerationConfig class, defaulting to None.
     b. When the parameters above are non-None, your technique would get triggered in generate, using the same pattern.
     c. After the code is in a nearly ready state, we'll run some benchmarks over different tasks and share them on social media.
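A toy sketch of the generalist structure proposed above: a greedy decoding loop that accepts any candidate-producing function and keeps only the candidates the target model agrees with. Everything here (the function names, and the "model" as a plain Python function over token ids) is hypothetical and is not transformers' assisted_decoding; it only illustrates why a draft model and prompt lookup can share one verification loop while producing exactly the greedy output.

```python
from typing import Callable, List, Tuple

NextTokenFn = Callable[[List[int]], int]        # target model's greedy choice
CandidateFn = Callable[[List[int]], List[int]]  # any drafter: small model, prompt lookup, ...


def greedy_decode(next_token: NextTokenFn, prompt: List[int], max_new_tokens: int) -> List[int]:
    seq = list(prompt)
    for _ in range(max_new_tokens):
        seq.append(next_token(seq))
    return seq


def decode_with_candidates(
    next_token: NextTokenFn,
    get_candidates: CandidateFn,
    prompt: List[int],
    max_new_tokens: int,
) -> Tuple[List[int], int]:
    """Greedy decoding where drafted tokens are verified by the target.

    Returns the sequence and the number of verification rounds; in a real
    implementation each round is a single forward pass of the target model.
    """
    seq = list(prompt)
    target_len = len(prompt) + max_new_tokens
    rounds = 0
    while len(seq) < target_len:
        rounds += 1
        candidates = get_candidates(seq)[: target_len - len(seq)]
        for cand in candidates:
            expected = next_token(seq)  # toy: one call per position
            if cand != expected:
                seq.append(expected)  # first mismatch: keep the target's token instead
                break
            seq.append(cand)  # candidate agrees with the target: accept it
        else:
            # Every candidate was accepted (or none was proposed): the same
            # verification pass also yields one extra token from the target.
            if len(seq) < target_len:
                seq.append(next_token(seq))
    return seq, rounds


def toy_next_token(seq: List[int]) -> int:
    """Hypothetical deterministic stand-in for a target model."""
    return 10 + len(seq) % 4


def perfect_drafter(seq: List[int]) -> List[int]:
    """Drafts the next 3 tokens the toy target would pick."""
    out = list(seq)
    for _ in range(3):
        out.append(toy_next_token(out))
    return out[len(seq):]


baseline = greedy_decode(toy_next_token, [1], max_new_tokens=8)
result, rounds = decode_with_candidates(toy_next_token, perfect_drafter, [1], max_new_tokens=8)
assert result == baseline  # same output as plain greedy decoding
print(rounds)  # prints 2 (versus 8 steps for plain greedy decoding)
```

Swapping `perfect_drafter` for a prompt lookup function (or a small model) changes only how many rounds are needed, never the output, which is what makes the "arbitrary candidate function" design attractive.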

Does it sound good to you? 🤗

(LMK if you'd like further pointers)

apoorvumang commented 11 months ago

Sounds good! I will try to read up and implement it

Agree on the assisted_generation refactoring as well; maybe we could even have a user-provided assistant_function? (But that's a software decision I'm not qualified to make.)

gante commented 11 months ago

@apoorvumang #27750 has the new assisted_decoding structure. It is still subject to review, but you can now have a concrete idea of what I had in mind :)

Adding your technique on top of it should be straightforward!

apoorvumang commented 11 months ago

Thanks @gante! Will look into the refactored code now. I think I should be able to get something running by tonight (IST).

apoorvumang commented 11 months ago

I have made a working implementation here, based off of #27750 : https://github.com/apoorvumang/transformers/tree/prompt_lookup_decoding . Should I start a PR with it?

apoorvumang commented 11 months ago

Also, if you have any suggested benchmarks or benchmarking code, I can help with that. I have access to an A100 40GB GPU and an M1 Max 32GB, @gante.

gante commented 11 months ago

@apoorvumang Yes, open a PR! I can add a few suggestions even before #27750 is merged :)

My advice for benchmarks would be the following: users love it when a certain method works well with little to no hyperparameters. At the moment, I see two hyperparameters -- prompt_lookup_num_tokens and prompt_lookup_max_matching_ngram. I'd run a few benchmarks over a few datasets changing these hyperparameters to find whether we can: a) get away with only one hyperparameter OR b) set an update heuristic that gets the best hyperparameters for the input at hand (through the update_candidate_strategy method)

If you find a way to make a) or b) work, the technique would become more user-friendly, and thus with a higher chance of being correctly used. For us, transformers maintainers, having fewer flags is also great!
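For option b), one possible shape of such an update heuristic is sketched below: lengthen the draft after a fully accepted draft, shorten it after a rejection. The function, its +2/-1 step sizes and its bounds are arbitrary assumptions for illustration, not values benchmarked in this thread; per the comment above, this kind of logic would go through the update_candidate_strategy method.

```python
def update_num_candidate_tokens(
    num_tokens: int,
    num_proposed: int,
    num_accepted: int,
    min_tokens: int = 1,
    max_tokens: int = 20,
) -> int:
    """Hypothetical schedule for the number of tokens to draft next step.

    Grows after a fully accepted draft, shrinks after a rejection, and stays
    unchanged when nothing was drafted (no n-gram match in the context).
    """
    if num_proposed == 0:
        return num_tokens  # no match: no evidence either way
    if num_accepted == num_proposed:
        return min(num_tokens + 2, max_tokens)
    return max(num_tokens - 1, min_tokens)


# Simulated (proposed, accepted) outcomes over a few generation steps.
num_tokens = 10
for proposed, accepted in [(10, 10), (12, 12), (14, 3), (0, 0), (13, 13)]:
    num_tokens = update_num_candidate_tokens(num_tokens, proposed, accepted)
    print(num_tokens)
```

The asymmetric steps make the draft length ramp up on highly repetitive inputs (e.g. code editing) while backing off gradually elsewhere; whether that beats a fixed value is exactly what the benchmarks would need to show.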

After we settle on a final implementation, I can validate the benchmarks on different devices (e.g. a T4, a 3090, ...). Given the simplicity of the technique, I suspect the results will be mostly hardware agnostic on GPU :)

apoorvumang commented 11 months ago

Started PR here: https://github.com/huggingface/transformers/pull/27775/commits . Please do leave suggestions @gante

I will start some benchmarking on my side to find optimal hyperparameters (or update schedules). Maybe both of these can best be tuned with just a default value plus an update schedule; if a user really wants to change the default value, they can instantiate and pass a PromptLookupCandidateGenerator with new parameters.

Will get back once I start some tests. I will try some standard summarization and QA datasets, and maybe look for a code-editing dataset.

apoorvumang commented 11 months ago

[Image: summarization benchmark results, greedy vs. sampling]

There is a significant difference between greedy decoding and sampling when summarizing, but there are still gains. A proper analysis of this phenomenon would probably be a paper-worthy effort.

I will try to run a similar experiment for code editing as well. If you think there's something else I could try, please let me know.

One question @gante: is the most popular method greedy or sampling? (I would assume greedy since it's the default, but I know sampling is better for quality.) If I could optimize for only one of these, which one should be the 'default'?

0xdevalias commented 11 months ago

> If I could optimize for only one of these, which one should be the 'default'?

Naive question/input here, but assuming you can figure out the optimisations and they don't apply equally to both, would it be possible to have two settings: one used with greedy and one used with sampling? Even if that's handled automagically under the hood (or exposed to users), it would presumably be simpler than having to know the exact hyperparameters to tune.

apoorvumang commented 11 months ago

Thanks! Yes, it can ofc: _get_candidate_generator has access to generation_config, which can be passed on here to check for things like this.
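A minimal sketch of the "two settings" idea, assuming only that the config object passed in exposes a do_sample flag (as transformers' generation_config does). PLDSettings, pick_pld_settings and the numbers are hypothetical placeholders, not tuned results, and SimpleNamespace stands in for a real generation config.

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any


@dataclass(frozen=True)
class PLDSettings:
    num_tokens: int
    max_ngram_size: int


# Placeholder numbers for illustration only; they are NOT tuned results.
GREEDY_SETTINGS = PLDSettings(num_tokens=10, max_ngram_size=3)
SAMPLING_SETTINGS = PLDSettings(num_tokens=5, max_ngram_size=2)


def pick_pld_settings(generation_config: Any) -> PLDSettings:
    """Choose PLD settings from the decoding mode of any config exposing `do_sample`."""
    if getattr(generation_config, "do_sample", False):
        return SAMPLING_SETTINGS
    return GREEDY_SETTINGS


print(pick_pld_settings(SimpleNamespace(do_sample=True)))
```

Keeping the choice inside the function that builds the candidate generator means users never see two sets of knobs; they only pick greedy or sampling as they already do.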

Any other thoughts/ideas @0xdevalias ?

0xdevalias commented 11 months ago

> Thanks! Yes it can ofc

@apoorvumang Awesome :)

> Any other thoughts/ideas?

@apoorvumang None at this stage; was more of a 'drive by random brain spark' type moment :)

gante commented 11 months ago

@apoorvumang @0xdevalias the preliminary results seem to point out that there is no obvious parameterization 🤔 Let's wait to see the results for coding!

Regarding sampling vs greedy: greedy is the default for legacy reasons, but sampling is by far the most popular with chat LLMs :) Tasks like summarization, translation, or automatic speech recognition tend to use greedy decoding or beam search, though.

Finally, regarding default values: we'll have to default the values to None, so we can detect whether the user wants to use it or not. We have a few default values for legacy reasons, but the defaults should be set at a model level (with the generation_config.json). This does not prevent us, however, from suggesting values in the parameters' docstring 🤗
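A small sketch of why None is the right default: the technique runs only when its parameter is explicitly set, and model-level defaults (as if read from a generation_config.json) are merged before user arguments. ToyGenerationConfig, build_config and decoding_mode are hypothetical stand-ins, not transformers' GenerationConfig or generate; only the field name prompt_lookup_num_tokens is borrowed from this thread.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class ToyGenerationConfig:
    """Hypothetical stand-in for a generation config (not transformers' class)."""

    do_sample: bool = False
    # None means "the user did not ask for prompt lookup decoding".
    prompt_lookup_num_tokens: Optional[int] = None


def build_config(model_defaults: Dict[str, Any], **user_kwargs: Any) -> ToyGenerationConfig:
    """Model-level defaults, then user kwargs on top (user values win)."""
    return ToyGenerationConfig(**{**model_defaults, **user_kwargs})


def decoding_mode(config: ToyGenerationConfig) -> str:
    # The technique is triggered only when its parameter is non-None.
    if config.prompt_lookup_num_tokens is not None:
        return "prompt_lookup"
    return "sample" if config.do_sample else "greedy"


print(decoding_mode(build_config({"do_sample": True})))
print(decoding_mode(build_config({"do_sample": True}, prompt_lookup_num_tokens=10)))
```

A numeric default would make "not requested" indistinguishable from "requested with the default value", so the suggested values belong in the docstring rather than in the field default.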

apoorvumang commented 11 months ago

Here are the results on mt-bench, using only the 2nd-turn coding samples:

[Image: mt-bench results, coding samples, 2nd turn only]
apoorvumang commented 11 months ago

All 80 samples from mt-bench, 2nd turn only.

[Image: mt-bench results, all 80 samples, 2nd turn only]
keyboardAnt commented 11 months ago

> All 80 samples from mt-bench, 2nd turn only. [image]

Hi @apoorvumang – Thanks for sharing your great work!

Two quick questions:

  1. What temperature did you use in "Sampling baseline" and "Sampling PLD"?
  2. How should we interpret the black-colored lines that go below 0? (What is their minimal tokens per second rate?)
gante commented 11 months ago

@keyboardAnt the error bars are usually the standard deviation of the measurement, which is a centered (and symmetric) moment -- it does not denote the minimum/maximum of a measurement, nor a range between percentiles.

As such, I'm reading it as a long-tailed distribution. Some speedups are huge (e.g. 5x), while most are moderate (e.g. 1.5x)
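A small numeric illustration of this reading, using made-up speedup values (not data from the plots above): every value is at least 1x, yet the mean minus one standard deviation is below 0, because a single large outlier inflates the standard deviation.

```python
import statistics

# Made-up per-prompt speedups: none is below 1x, one is a long-tail outlier.
speedups = [1.0, 1.1, 1.2, 1.3, 1.5, 10.0]

mean = statistics.mean(speedups)
std = statistics.stdev(speedups)  # sample standard deviation

print(f"min={min(speedups):.2f}  mean={mean:.2f}  std={std:.2f}  mean-std={mean - std:.2f}")
# The lower error bar (mean - std) is negative although no speedup is below 1x.
```

This is why a symmetric error bar below 0 does not by itself indicate a slowdown; it also cannot rule one out.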

apoorvumang commented 11 months ago

Hi @keyboardAnt , thank you!

  1. Default temperature, so probably 1.0
  2. As @gante said, the black lines are standard deviations, not minima or maxima. I didn't save the exact data for these, so I can't share it. But where the bars seem to go below 0, it's probably because of very high variance in the speedups (1x to 10x).

Here's an example of this phenomenon, courtesy of ChatGPT:

[Image: ChatGPT-generated example of this phenomenon]

PS: Sorry for the delay in working on this PR - I will try to work on it this weekend

keyboardAnt commented 11 months ago

@gante, @apoorvumang, yes. Because of the high variance, we had better also consider the minimal tokens/sec rate. That would show whether the long tail is one-sided; otherwise, the plots might suggest a slowdown.

apoorvumang commented 11 months ago

@keyboardAnt Could you please expand on what you mean? Like we should look for configs with a good lower bound for tokens/sec rather than a good average?

keyboardAnt commented 11 months ago

@apoorvumang, my suggestion is to measure speedup. That is

speedup := (tokens per second with PLD) / (tokens per second without PLD)

where with-PLD and without-PLD share the same variables (e.g., prompt, target model, GPU device). We want to show that speedup >> 1 in most cases, and to rule out the possibility that speedup < 1 (i.e., a slowdown). The visualizations you shared do not rule out the possibility that speedup < 1.

We must measure speedup in varied configurations so we can better understand it. Each configuration is a unique combination of prompt, target model, and (max_matching_ngram, num_token_output) hyperparameters. Visualizing the distribution of speedup and calculating its harmonic mean can help.
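A minimal sketch of the suggested analysis, assuming paired measurements where each pair shares prompt, target model, device and hyperparameters. summarize_speedups and the numbers are hypothetical. Reporting the minimum and the fraction of pairs with speedup < 1 addresses slowdowns directly, which an error bar around the mean cannot do.

```python
import statistics
from typing import Dict, List, Tuple

# (tokens/sec with PLD, tokens/sec without PLD) for one configuration.
Measurement = Tuple[float, float]


def summarize_speedups(pairs: List[Measurement]) -> Dict[str, float]:
    """Per-configuration speedups summarized by minimum, harmonic mean and slowdown rate."""
    speedups = [with_pld / without_pld for with_pld, without_pld in pairs]
    return {
        "min": min(speedups),
        "harmonic_mean": statistics.harmonic_mean(speedups),
        "fraction_slowdown": sum(s < 1.0 for s in speedups) / len(speedups),
    }


# Made-up numbers: one configuration is slower with PLD, one is much faster.
report = summarize_speedups([(30.0, 20.0), (20.0, 20.0), (18.0, 20.0), (100.0, 20.0)])
print(report)
```

One reason to prefer the harmonic mean: if both runs in each pair generate the same number of tokens and every configuration takes the same time without PLD, the harmonic mean of the speedups equals the total time without PLD divided by the total time with PLD, so a few huge speedups cannot dominate it the way they dominate the arithmetic mean.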

github-actions[bot] commented 10 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

keyboardAnt commented 5 months ago

> @apoorvumang, my suggestion is to measure speedup. That is
>
> speedup := (tokens per second with PLD) / (tokens per second without PLD)
>
> where with-PLD and without-PLD share the same variables (e.g., prompt, target model, GPU device). We want to show that speedup >> 1 in most cases, and to rule out the possibility that speedup < 1 (i.e., a slowdown). The visualizations you shared do not rule out the possibility that speedup < 1.
>
> We must measure speedup in varied configurations so we can better understand it.

We recently released a preprint that also covers the question of slowdowns: https://arxiv.org/pdf/2405.14105

Our experiments show that slowdowns do occur in practice (for example, if PLD is too slow or inaccurate). We also propose a novel algorithm for running PLD (or any other drafter) in parallel so that there are no slowdowns.