Closed · stas-sl closed this 1 year ago
@stas-sl hey Stas, i was just looking through the code today (while adding contrastive decoding) and wondering why it was so bad lol. i did a little cleanup so you can pass in any filter kwargs, also allowing for passing `k` directly for `top_k`
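roughly the idea (a sketch, not the exact code; the loop body here is illustrative):

```python
import torch
import torch.nn.functional as F

def generate(model, prompt, seq_len, filter_logits_fn, **filter_kwargs):
    # whatever keyword arguments the caller passes are forwarded to the
    # chosen filter, e.g. k=... for top_k or thres=... for top_p
    out = prompt
    for _ in range(seq_len):
        logits = model(out)[:, -1]
        probs = F.softmax(filter_logits_fn(logits, **filter_kwargs), dim=-1)
        out = torch.cat((out, torch.multinomial(probs, 1)), dim=-1)
    return out
```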
also think you are right about the nucleus sampling threshold being inverted, and corrected it
thanks for raising this!
Thanks for the fast response and code improvements, it looks better now. Though there are still some details I'm a bit confused about. What is the idea of `thres` in `top_k`? Should it be there at all? It looks confusing, as it has the same name as in `top_p`, but probabilities are not thresholded there as in `top_p`: it just takes a percentage/fraction of `num_tokens`, and, more confusingly, it takes `1 - thres`. So if I'd like to sample using the top 10% of words on each step, I would need to pass `thres=0.9`, which seems not very convenient to me. I would propose 3 solutions:
1. Remove `thres` from `top_k` altogether.
2. Or at least remove the subtraction from 1, and just use `thres` as is.
3. Or combine `thres` and `k` into one argument named `k`: if it is an integer, take exactly `k` top tokens; if it is a float between 0 and 1, take `ceil(k * num_tokens)` (sketched below).
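A rough sketch of option 3 (illustrative only, not a patch):

```python
import math
import torch

def top_k(logits, k=0.9):
    # int k -> keep exactly k tokens; float k in (0, 1) -> keep that
    # fraction of the vocabulary, rounded up
    num_tokens = logits.shape[-1]
    k = k if isinstance(k, int) else math.ceil(k * num_tokens)
    val, ind = logits.topk(k, dim=-1)
    filtered = torch.full_like(logits, float('-inf'))
    filtered.scatter_(-1, ind, val)
    return filtered
```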
And considering the `top_a` implementation, don't you think it would be more correct as I described (taking the max individually per each minibatch sample)?
@stas-sl yea, you are right about `top_a`; made the fix, although will probably remove it at some point in the future (it was some filtering strategy suggested in a chatroom, not in the lit)
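i.e. taking the max per sample rather than one global max over the whole batch, roughly (a sketch, not the exact repo code):

```python
import torch

def top_a(logits, min_p_pow=2.0, min_p_ratio=0.02):
    probs = logits.softmax(dim=-1)
    # max per sample (keepdim so it broadcasts back over the vocab),
    # rather than one global max over the entire batch
    limit = probs.amax(dim=-1, keepdim=True).pow(min_p_pow) * min_p_ratio
    return torch.where(probs < limit, torch.full_like(logits, float('-inf')), logits)
```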
i refactored the `top_k` a bit, let me know how it looks!
Thanks, now it looks much more readable to me!
Hi and thanks for your impressive work!
I'm looking at the different sampling methods in `AutoregressiveWrapper` and I have some doubts whether they are implemented correctly, or I'm stupid 🤪

In the `top_p` implementation, shouldn't we remove indices where `cum_probs > thres` (instead of the current `1 - thres`)? From the article: (screenshot of the nucleus sampling definition omitted). Or is the meaning of `thres` here inverted compared to the article?
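Something like this is the direction I mean (just a sketch; the exact tensor ops may differ):

```python
import torch

def top_p(logits, thres=0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # drop tokens once the cumulative probability exceeds thres...
    sorted_mask = cum_probs > thres
    sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
    sorted_mask[..., 0] = False  # ...but always keep the most likely token
    sorted_logits[sorted_mask] = float('-inf')
    # scatter the masked logits back to their original positions
    return sorted_logits.scatter(-1, sorted_indices, sorted_logits)
```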
In `top_k` sampling, why not just pass `k` as an argument? Basically, I'd like to perform top-1 (just greedy) sampling and I'm not sure how to achieve this.
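For instance (illustrative):

```python
import torch

logits = torch.randn(2, 50)     # dummy (batch, vocab) logits
greedy = logits.argmax(dim=-1)  # top-1, i.e. greedy decoding
# with an integer k argument this would simply be top_k(logits, k=1)
```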
In the `top_a` implementation, it seems more reasonable to me to take the max per each sample in the batch (not the global max over all samples), so it should probably be something like this (note `.values`, since `torch.max` with `dim` returns a named tuple):

`limit = torch.pow(torch.max(probs, dim=-1, keepdim=True).values, min_p_pow) * min_p_ratio`
And an overall observation: as there could be different sampling functions, each with their own set of arguments, it seems not very elegant to have all these arguments passed to `AutoregressiveWrapper.generate` and then have if/else inside. Why not just pass a partial with bound arguments: `model.generate(..., filter_logits_fn=partial(top_p, thres=0.9))`?
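Roughly like this (a sketch assuming the `top_p` above; the `generate` body is hypothetical, just to show the shape of the API):

```python
import torch
from functools import partial

def generate(model, prompt, seq_len, filter_logits_fn=partial(top_p, thres=0.9)):
    # generate no longer needs to know each filter's arguments;
    # callers bind them via functools.partial
    out = prompt
    for _ in range(seq_len):
        logits = model(out)[:, -1]
        probs = filter_logits_fn(logits).softmax(dim=-1)
        out = torch.cat((out, torch.multinomial(probs, 1)), dim=-1)
    return out

# usage: model.generate(..., filter_logits_fn=partial(top_k, k=1))
```

This way the signature of `generate` stays the same no matter which filter is used or what arguments it takes.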