Closed dmelcer9 closed 1 month ago
It looks like that change (https://github.com/huggingface/transformers/pull/30778) was made to save time, since if the token were sampled with the correct probability then it would also have to be re-sampled some of the time. I think you are right; that would have to be reverted to restore correctness. Unless "assisted decoding" is meant to behave differently from speculative sampling (https://github.com/huggingface/transformers/issues/26565#issuecomment-1766409705) and not have the correctness property (Appendix A.1). @gante
Also, I suspect (but haven't done the math) that correctness would technically only hold if the probabilities here were adjusted by temperature/min_p/etc. (they seem to just be temperature-1 probabilities right now).
Hi @dmelcer9 @llllvvuu 👋
I agree that forcing greedy decoding on the assistant has to be reverted -- empirically, we can check that sampling with an assistant model has much smaller entropy than without an assistant model (and they should be the same). It was a rushed decision on my end before, I will open a PR to revert it.
However, I disagree with some of your statements on why it must be done 🤔 You wrote:
This [forcing greedy decoding] is equivalent to setting the temperature to zero, so the output probability of the assistant model should always be 1 (for the selected token).
We have to distinguish two phases of text generation: producing the distribution for the next token, and selecting the next token given that distribution. Under both sampling and greedy decoding, the distribution for the next token is the same. Greedy decoding does not set the temperature to 0; it simply takes the argmax of the distribution instead of sampling from it. The probability properties of speculative decoding at a token level do hold even with the probability distributions for the next tokens with greedy decoding. However, speculative decoding assumes we are sampling from the next-token distribution in the assistant model, which is simply not happening at the moment and results in quasi-deterministic outputs from speculative decoding.
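For reference, the token-level accept/reject rule from the speculative sampling paper can be sketched as follows (a minimal, illustrative Python version; `speculative_accept` is a made-up name, not the transformers implementation):

```python
import random

def speculative_accept(q, p, drafted, rng):
    """One token-level accept/reject step of speculative sampling.

    q: draft (assistant) distribution, p: main-model distribution,
    drafted: token index proposed by the draft model.
    If `drafted` is sampled from q, the returned token is distributed as p.
    """
    # Accept the drafted token with probability min(1, p_i / q_i).
    if rng.random() < min(1.0, p[drafted] / q[drafted]):
        return drafted
    # On rejection, resample from the normalized residual max(0, p - q).
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(residual)
    return rng.choices(range(len(p)), weights=[r / total for r in residual])[0]
```

The correctness guarantee relies on `drafted` being a sample from `q`; if the draft token is instead always `argmax(q)`, which is what greedy assistant decoding amounts to, the guarantee no longer holds, which is the failure mode discussed in this thread.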
@gante Thanks for opening the PR. I'm not quite sure what you mean in the second part
Greedy decoding does not set the temperature to 0, it simply takes the argmax of the distribution instead of sampling.
I was definitely a bit mathematically imprecise earlier: while of course greedy decoding doesn't involve dividing the logits by 0 before the softmax, the output of the softmax function in the limit as $t \rightarrow 0$ is a one-hot distribution, i.e. the deterministic result of greedy decoding.
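This limit is easy to check numerically; a minimal sketch with illustrative logits:

```python
import math

def softmax_with_temperature(logits, t):
    # softmax(logits / t); subtract the max for numerical stability
    scaled = [l / t for l in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    z = sum(exps)
    return [e / z for e in exps]

logits = [2.0, 1.0, 0.5]
for t in (1.0, 0.1, 0.01):
    print(t, [round(pr, 4) for pr in softmax_with_temperature(logits, t)])
# As t -> 0 the distribution collapses onto argmax(logits), i.e. one-hot.
```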
This was meant in the context that, if greedy decoding is used in the assistant model, then q_i "should be" the one-hot vector, even though the following line uses the sampling probability as q_i:
The probability properties of speculative decoding at a token level do hold even with the probability distributions for the next tokens with greedy decoding.
Not always: in the two-token-vocabulary case, if the speculative model outputs [0.8, 0.2] and the main model outputs [0.6, 0.4], the output will still be probabilistic but with the wrong distribution. Token 0 will always be chosen from the draft model, and the probability ratio will be calculated as 0.6 / 0.8 = 0.75. In the remaining 25% of the time, the residuals will be calculated as norm([clamp(0.6 - 0.8), clamp(0.4 - 0.2)]) = [0, 1], so token 1 will be sampled. The overall probability distribution has changed to [0.75, 0.25] instead of being [0.6, 0.4] (ideal), or being deterministic.
Also note that if q_i = [1, 0], we end up with the correct output distribution (though keeping q_i = [0.8, 0.2] is fine if the speculative model is changed to use sampling instead of greedy).
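The arithmetic in this example can be checked exactly. A minimal sketch (illustrative, not the transformers code) that computes the output distribution when the draft token is always argmax(q):

```python
def output_dist_greedy_draft(q, p):
    """Exact output distribution of the accept/reject step when the
    draft token is always argmax(q) instead of a sample from q."""
    i = max(range(len(q)), key=lambda k: q[k])   # greedy draft token
    accept = min(1.0, p[i] / q[i])               # acceptance probability
    # On rejection, sample from the normalized residual max(0, p - q).
    residual = [max(0.0, pj - qj) for pj, qj in zip(p, q)]
    total = sum(residual)
    out = [0.0] * len(p) if total == 0.0 else [(1.0 - accept) * r / total for r in residual]
    out[i] += accept
    return out

print(output_dist_greedy_draft([0.8, 0.2], [0.6, 0.4]))  # ~[0.75, 0.25], not [0.6, 0.4]
print(output_dist_greedy_draft([1.0, 0.0], [0.6, 0.4]))  # ~[0.6, 0.4], correct
```

This reproduces both cases above: the greedy draft with q = [0.8, 0.2] skews the output to [0.75, 0.25], while the one-hot q = [1, 0] recovers the main model's [0.6, 0.4].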
@dmelcer9 agreed with what you wrote 🤗
The probability properties of speculative decoding at a token level do hold even with the probability distributions for the next tokens with greedy decoding.
I was imprecise here indeed. It is missing ", if we were sampling from that distribution to get the next token from the assistant" :)
System Info

transformers version: 4.44.0

Who can help?

No response

Information

Tasks

examples folder (such as GLUE/SQuAD, ...)

Reproduction
In the speculative sampling procedure (https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e5411a8026/src/transformers/generation/utils.py#L4130), the probability ratio is calculated relative to the output probability of the assistant model.
However, the speculative model is always used greedily: https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e5411a8026/src/transformers/generation/candidate_generator.py#L158
This is equivalent to setting the temperature to zero, so the output probability of the assistant model should always be 1 (for the selected token).
As a more concrete example, if the assistant model outputs [0.51, 0.49], then as long as the main model outputs [x >= 0.51, y <= 0.49], this will lead to the first token always being sampled by the procedure.

This is evident when you use a model as its own assistant, at least for the first 5 tokens from the speculative model (there is still some randomness from the extra token generated by the main model but not the assistant).
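A tiny sketch of the acceptance ratio for this example (illustrative names, not the transformers code):

```python
def acceptance_prob(q, p, drafted):
    # Speculative-sampling acceptance probability min(1, p_i / q_i).
    return min(1.0, p[drafted] / q[drafted])

# Assistant distribution [0.51, 0.49]; greedy decoding always drafts token 0.
q = [0.51, 0.49]
for p0 in (0.51, 0.7, 0.9):
    p = [p0, 1.0 - p0]
    # Whenever the main model puts at least 0.51 on token 0, the draft
    # is accepted with probability 1, so token 0 is always emitted.
    print(p, acceptance_prob(q, p, 0))
```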
Expected behavior
Assisted decoding should use a correct sampling method.