huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Speculative sampling does not maintain probability distribution of main model #32867

Closed: dmelcer9 closed this issue 1 month ago

dmelcer9 commented 2 months ago

Reproduction

In the speculative sampling procedure (https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e5411a8026/src/transformers/generation/utils.py#L4130), each draft token is accepted based on the probability ratio p_i / q_i, where p_i is the main model's probability and q_i is the assistant model's output probability for that token.

However, the assistant (speculative) model is always decoded greedily: https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e5411a8026/src/transformers/generation/candidate_generator.py#L158

This is equivalent to setting the assistant's temperature to zero, so the assistant model's output probability for the selected token should always be 1.

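As a more concrete example, take a two-token vocabulary where the assistant model outputs [0.51, 0.49]. The greedy draft always proposes token 0, and as long as the main model outputs [x >= 0.51, y <= 0.49], the ratio x / 0.51 is at least 1, so token 0 is always accepted and token 1 is never sampled by the procedure.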

This is evident when you use a model as its own assistant: p_i = q_i, so every greedy draft token is accepted and the output is deterministic, at least for the first 5 tokens from the speculative model (there is still some randomness from the extra token generated by the main model but not by the assistant).
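The acceptance rule described above can be sketched in a few lines of NumPy. This is a minimal sketch, not the transformers code: `speculative_step` and `sample_frequencies` are hypothetical helpers that follow the standard speculative sampling rule for a single token (draw a draft token x, accept it with probability min(1, p[x] / q[x]), otherwise resample from the normalized residual max(0, p - q)). The assumed `greedy_draft` flag mimics an assistant that always takes argmax(q) while the acceptance test still uses q, which is the situation described in this issue.

```python
import numpy as np

def speculative_step(p, q, rng, greedy_draft=False):
    """Emit one token with the speculative sampling accept/resample rule.

    Hypothetical helper: p are main-model probabilities, q are assistant
    probabilities. With greedy_draft=True the draft is argmax(q), but the
    acceptance test still uses q, mirroring the behavior in this issue.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    x = int(np.argmax(q)) if greedy_draft else int(rng.choice(len(q), p=q))
    # Accept the draft token with probability min(1, p[x] / q[x]).
    if rng.random() < min(1.0, p[x] / q[x]):
        return x
    # On rejection, resample from the normalized residual max(0, p - q).
    residual = np.maximum(p - q, 0.0)
    return int(rng.choice(len(p), p=residual / residual.sum()))

def sample_frequencies(p, q, greedy_draft, n=10_000, seed=0):
    """Empirical frequency of each emitted token over n independent steps."""
    rng = np.random.default_rng(seed)
    tokens = [speculative_step(p, q, rng, greedy_draft) for _ in range(n)]
    return np.bincount(tokens, minlength=len(p)) / n

q = [0.51, 0.49]  # assistant distribution from the example above
p = [0.70, 0.30]  # any main-model distribution with p[0] >= 0.51

print(sample_frequencies(p, q, greedy_draft=True))   # [1. 0.]: token 1 never emitted
print(sample_frequencies(p, q, greedy_draft=False))  # close to p
print(sample_frequencies(q, q, greedy_draft=True))   # model as its own assistant: [1. 0.]
```

With a greedy draft, p[0] >= q[0] makes the acceptance ratio at least 1, so the rejection branch is never reached; the same happens when p equals q. Drawing the draft from q instead brings the frequencies back close to p.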

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "openai-community/gpt2-medium"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

inputs = tokenizer("public int", return_tensors="pt")

# Greedy
# Always outputs `public int get_current_time()` (and then some)
print(tokenizer.decode(model.generate(**inputs, do_sample=False, max_new_tokens=25)[0]))

# Sampling
# Gives different method names each time
print(tokenizer.decode(model.generate(**inputs, do_sample=True, max_new_tokens=25)[0]))

# Assisted generation with the model as its own assistant
# Should theoretically be sampling but is not:
# always outputs `public int get_current_time()`
print(tokenizer.decode(model.generate(**inputs, assistant_model=model, do_sample=True, max_new_tokens=25)[0]))
```

Expected behavior

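Assisted decoding should use a correct sampling method, so that `do_sample=True` with an assistant model produces the same output distribution as sampling from the main model alone.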

llllvvuu commented 2 months ago

It looks like that change (https://github.com/huggingface/transformers/pull/30778) was made to save time: if the draft token were sampled with the correct probability, the main model would also have to re-sample some of the time. I think you are right that the change would have to be reverted to restore correctness, unless "assisted decoding" is meant to behave differently from speculative sampling (https://github.com/huggingface/transformers/issues/26565#issuecomment-1766409705) and not have the correctness property (Appendix A.1). @gante

Also, I suspect (but haven't done the math) that correctness would technically hold only if the probabilities here were adjusted by temperature, min_p, etc. (right now they seem to be plain temperature-1 probabilities):

https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e5411a8026/src/transformers/generation/utils.py#L4126-L4130
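A quick numerical check of this suspicion, for temperature only, can be sketched under stated assumptions: the draft token is drawn from the temperature-warped assistant distribution q_T, the intended target is the warped main distribution p_T, and the vocabulary has two tokens with made-up logits. `output_distribution` is a hypothetical helper (not transformers code) that computes the exact distribution of the emitted token for one speculative step. In this sketch, using warped probabilities everywhere recovers p_T, while using temperature-1 probabilities in the ratio and residual does not.

```python
import numpy as np

def softmax(logits, temperature=1.0):
    z = np.asarray(logits, dtype=float) / temperature
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()

def output_distribution(p, q, draft):
    """Exact distribution of the emitted token for one speculative step.

    p, q: probabilities used in the acceptance ratio and the residual.
    draft: distribution the draft token is actually drawn from.
    """
    p, q, draft = (np.asarray(a, dtype=float) for a in (p, q, draft))
    # Acceptance probability min(1, p/q); tokens with q == 0 are always accepted.
    ratio = np.divide(p, q, out=np.ones_like(p), where=q > 0)
    accepted = draft * np.minimum(1.0, ratio)  # P(draft == x and accepted)
    residual = np.maximum(p - q, 0.0)
    if residual.sum() > 0:
        residual = residual / residual.sum()
    return accepted + (1.0 - accepted.sum()) * residual

# Hypothetical logits for a two-token vocabulary and a sampling temperature.
main_logits = np.log([0.6, 0.4])
draft_logits = np.log([0.8, 0.2])
T = 0.5

p_T, q_T = softmax(main_logits, T), softmax(draft_logits, T)  # warped
p_1, q_1 = softmax(main_logits), softmax(draft_logits)        # temperature 1

consistent = output_distribution(p_T, q_T, draft=q_T)
mismatched = output_distribution(p_1, q_1, draft=q_T)

print(p_T)         # approximately [0.692, 0.308]: the intended target
print(consistent)  # equals p_T
print(mismatched)  # approximately [0.706, 0.294]: not p_T
```

The same reasoning would presumably extend to other logits processors such as min_p, but this sketch only exercises temperature.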

gante commented 1 month ago

Hi @dmelcer9 @llllvvuu 👋

I agree that forcing greedy decoding on the assistant has to be reverted -- empirically, we can check that sampling with an assistant model has much smaller entropy than without an assistant model (and they should be the same). It was a rushed decision on my end before, I will open a PR to revert it.


However, I disagree with some of your statements on why it must be done 🤗 You wrote:

> This [forcing greedy decoding] is equivalent to setting the temperature to zero, so the output probability of the assistant model should always be 1 (for the selected token).

We have to distinguish two phases of text generation: producing the distribution for the next token, and selecting the next token given that distribution. With both sampling and greedy decoding, the distribution for the next token is the same. Greedy decoding does not set the temperature to 0, it simply takes the argmax of the distribution instead of sampling. The probability properties of speculative decoding at a token level do hold even with the probability distributions for the next tokens with greedy decoding. However, speculative decoding assumes we are sampling from the assistant model's next-token distribution, which is simply not happening at the moment and results in quasi-deterministic outputs from speculative decoding.

dmelcer9 commented 1 month ago

@gante Thanks for opening the PR. I'm not quite sure what you mean in the second part:

> Greedy decoding does not set the temperature to 0, it simply takes the argmax of the distribution instead of sampling.

I was definitely a bit mathematically imprecise earlier. Of course greedy decoding doesn't involve dividing the logits by 0 before the softmax, but the output of the softmax function in the limit as $t \rightarrow 0$ is a one-hot distribution, i.e. the deterministic result of greedy decoding.
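The limit argument can be checked numerically. The sketch below uses hypothetical logits and NumPy only: it evaluates softmax(z / t) for decreasing t and confirms that the argmax (the token greedy decoding picks from the t = 1 distribution) never changes, while the distribution approaches the one-hot vector at that argmax. This assumes the largest logit is unique; with ties, the limit spreads mass over the tied tokens.

```python
import numpy as np

def softmax(logits, t=1.0):
    z = np.asarray(logits, dtype=float) / t
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()

logits = np.array([2.0, 1.5, -1.0])    # hypothetical logits, unique maximum
greedy_token = int(np.argmax(logits))  # what greedy decoding picks

for t in (1.0, 0.5, 0.1, 0.01):
    probs = softmax(logits, t)
    # The argmax is the same at every temperature; only the sharpness changes.
    assert int(np.argmax(probs)) == greedy_token
    print(t, np.round(probs, 4))  # at t = 0.01 the rounded row is [1. 0. 0.]
```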

My point was that if the assistant model decodes greedily, q_i "should be" that one-hot vector, even though the following line uses the sampling probability as q_i:

https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e5411a8026/src/transformers/generation/utils.py#L4130

> The probability properties of speculative decoding at a token level do hold even with the probability distributions for the next tokens with greedy decoding.

Not always: in the two-token vocabulary case, if the speculative model outputs [0.8, 0.2] and the main model outputs [0.6, 0.4], the output is still random but has the wrong distribution. The greedy draft always proposes token 0, and the probability ratio is computed as 0.6 / 0.8 = 0.75. In the remaining 25% of cases, the residual is computed as norm([clamp(0.6 - 0.8), clamp(0.4 - 0.2)]) = norm([0, 0.2]) = [0, 1] (clamping negatives to 0), so token 1 is sampled. The overall distribution has changed to [0.75, 0.25]: it is neither the ideal [0.6, 0.4] nor deterministic.

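Also note that if q_i = [1, 0], we end up with the correct output distribution: token 0 is accepted with probability 0.6 / 1 = 0.6, and the residual norm([clamp(0.6 - 1), clamp(0.4 - 0)]) is again [0, 1], giving [0.6, 0.4]. (Keeping q_i = [0.8, 0.2] is fine if the speculative model is changed to sample instead of decoding greedily.)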
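The two-token arithmetic above can be reproduced exactly. In this sketch, `output_distribution` is a hypothetical helper (not transformers code) that returns the distribution of the emitted token for one speculative step, given the probabilities used in the acceptance ratio and residual and, separately, the distribution the draft token is actually drawn from. It covers the three cases discussed: a greedy draft scored with q_i = [0.8, 0.2], a greedy draft scored with the one-hot q_i = [1, 0], and a draft sampled from q_i = [0.8, 0.2].

```python
import numpy as np

def output_distribution(p, q, draft):
    """Exact distribution of the token emitted by one speculative-sampling step.

    p, q: main/assistant probabilities used in the ratio and the residual.
    draft: distribution the draft token is actually drawn from.
    """
    p, q, draft = (np.asarray(a, dtype=float) for a in (p, q, draft))
    # Acceptance probability min(1, p/q); tokens with q == 0 are always accepted.
    ratio = np.divide(p, q, out=np.ones_like(p), where=q > 0)
    accepted = draft * np.minimum(1.0, ratio)  # P(draft == x and accepted)
    reject_mass = 1.0 - accepted.sum()
    residual = np.maximum(p - q, 0.0)          # clamp(p - q) at 0
    if residual.sum() > 0:
        residual = residual / residual.sum()
    return accepted + reject_mass * residual

p = [0.6, 0.4]
q = [0.8, 0.2]
greedy = [1.0, 0.0]  # argmax(q) as a one-hot draft distribution

print(output_distribution(p, q, draft=greedy))       # ~[0.75, 0.25]: wrong
print(output_distribution(p, greedy, draft=greedy))  # ~[0.6, 0.4]: one-hot q_i
print(output_distribution(p, q, draft=q))            # ~[0.6, 0.4]: sampled draft
```

Separating the scoring distribution from the drafting distribution makes the point of the thread explicit: correctness only requires that the q_i used in the ratio is the distribution the draft token was actually drawn from.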

gante commented 1 month ago

@dmelcer9 agreed with what you wrote 🤗

> The probability properties of speculative decoding at a token level do hold even with the probability distributions for the next tokens with greedy decoding.

I was imprecise here indeed. It is missing ", if we were sampling from that distribution to get the next token from the assistant" :)