lucidrains / x-transformers

A concise but complete full-attention transformer with a set of promising experimental features from various papers
MIT License

Sampling questions/(issues?) #184

Closed · stas-sl closed this 1 year ago

stas-sl commented 1 year ago

Hi and thanks for your impressive work!

I'm looking at the different sampling methods in AutoregressiveWrapper, and I have some doubts about whether they are implemented correctly or I'm just being stupid 🤪


from math import ceil

import torch
import torch.nn.functional as F

def top_p(logits, thres = 0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_indices_to_remove = cum_probs > (1 - thres)
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = 0

    sorted_logits[sorted_indices_to_remove] = float('-inf')
    return sorted_logits.scatter(1, sorted_indices, sorted_logits)

# topk

def top_k(logits, thres = 0.9):
    k = ceil((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

# top_a

def top_a(logits, min_p_pow=2.0, min_p_ratio=0.02):
    probs = F.softmax(logits, dim=-1)
    limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio
    logits[probs < limit] = float('-inf')
    logits[probs >= limit] = 1
    return logits

In the top_p implementation, shouldn't we remove indices where cum_probs > thres (instead of the current 1 - thres)? From the article:

Top-p sampling (or nucleus sampling) chooses from the smallest possible set of words whose cumulative probability exceeds the probability p. This way, the number of words in the set can dynamically increase and decrease according to the next word probability distribution.

Or is the meaning of thres here inverted compared to the article?
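
To make that concrete, here is a minimal sketch (my rewrite, not the library's code) of how I would expect the cutoff to read if thres is the nucleus mass p from the article:

import torch
import torch.nn.functional as F

def top_p_article(logits, thres = 0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending = True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim = -1), dim = -1)

    # mark everything after the cumulative probability first exceeds thres ...
    sorted_indices_to_remove = cum_probs > thres
    # ... shifted right by one, so the token that crosses the threshold is kept
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = False

    sorted_logits[sorted_indices_to_remove] = float('-inf')
    # scatter the filtered logits back to their original vocabulary positions
    return sorted_logits.scatter(1, sorted_indices, sorted_logits)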

In top_k sampling, why not just pass k as an argument? Basically, I'd like to perform top-1 (greedy) sampling, and I'm not sure how to achieve that.
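
For comparison, a tiny sketch (my naming) of a top_k that takes k directly, where k = 1 reduces to greedy decoding:

import torch

def top_k_direct(logits, k = 1):
    # keep only the k largest logits per sample; with k = 1 this is greedy decoding
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs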

In the top_a implementation, it seems more reasonable to me to take the max per sample in the batch (not the global max over all samples), so it should probably be something like: limit = torch.pow(torch.max(probs, dim=-1, keepdim=True).values, min_p_pow) * min_p_ratio
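
Spelled out, the variant I have in mind looks like this (my naming; note torch.max with dim returns a (values, indices) namedtuple, hence .values):

import torch
import torch.nn.functional as F

def top_a_per_sample(logits, min_p_pow = 2.0, min_p_ratio = 0.02):
    probs = F.softmax(logits, dim = -1)
    # max over the vocabulary dimension only, so each sample in the batch
    # gets its own cutoff instead of sharing one global maximum
    max_probs = torch.max(probs, dim = -1, keepdim = True).values
    limit = torch.pow(max_probs, min_p_pow) * min_p_ratio
    logits = logits.clone()   # avoid mutating the caller's tensor in place
    logits[probs < limit] = float('-inf')
    logits[probs >= limit] = 1
    return logits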

And an overall observation: since there can be different sampling functions, each with its own set of arguments, it doesn't seem very elegant to pass all of those arguments to AutoregressiveWrapper.generate and then branch with if/else inside. Why not just pass a partial with bound arguments: model.generate(..., filter_logits_fn=partial(top_p, thres=0.9))?
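
As a self-contained illustration of the mechanism (reusing the top_p quoted above and a random batch of logits, just to show that the bound argument travels with the function):

from functools import partial

import torch

# thres is bound once, up front; the caller only ever needs filter_fn(logits)
filter_fn = partial(top_p, thres = 0.9)

logits = torch.randn(2, 100)   # (batch, vocab)
filtered = filter_fn(logits)   # same shape; filtered-out positions are -inf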

lucidrains commented 1 year ago

@stas-sl hey Stas, i was just looking through the code today (while adding contrastive decoding) and wondering why it was so bad lol. i did a little cleanup so you can pass in any filter kwargs, also allowing for passing k directly for topk
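
roughly, usage now looks something like this (off the top of my head, so treat the exact kwarg names as approximate):

from x_transformers.autoregressive_wrapper import top_k

# `model` is any AutoregressiveWrapper instance, `prompt` a (batch, time) LongTensor
out = model.generate(
    prompt,
    seq_len = 256,
    filter_logits_fn = top_k,        # or top_p / top_a / your own callable
    filter_kwargs = dict(k = 10),    # forwarded straight to the filter function
)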

also think you are right about the nucleus sampling threshold being inverted, and corrected it

thanks for raising this!

stas-sl commented 1 year ago

Thanks for the fast response and the code improvements, it looks better now. There are still some details I'm a bit confused about, though. What is the idea behind thres in top_k? Should it be there at all? It's confusing because it has the same name as in top_p, but probabilities are not thresholded the way they are in top_p; it just takes a fraction of num_tokens, and more confusingly it uses 1 - thres. So if I'd like to sample from the top 10% of words at each step, I would need to pass thres=0.9, which doesn't seem very convenient to me. I would propose 3 solutions:

  1. Either ditch thres in top_k entirely. That would be the simplest solution and probably what most people expect top_k to do. If someone needs a fraction of the vocabulary, they can compute k manually.
  2. If you still want to use thres, maybe at least drop the subtraction from 1 and use thres as is.
  3. Combine thres and k into one argument named k: if it is an integer, take exactly the top k tokens; if it is a float between 0 and 1, take ceil(k * num_tokens). (See the sketch just below this list.)
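
A quick sketch of option 3 (my naming, not a patch):

from math import ceil

import torch

def top_k_flexible(logits, k = 0.1):
    num_tokens = logits.shape[-1]
    # a float in (0, 1) is a fraction of the vocabulary, an int is an exact count
    if isinstance(k, float):
        k = ceil(k * num_tokens)
    k = max(1, min(k, num_tokens))
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs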

And regarding the top_a implementation, don't you think it would be more correct as I described, taking the max individually for each sample in the minibatch?

lucidrains commented 1 year ago

@stas-sl yea, you are right about top_a; made the fix although will probably remove it at some point in the future (it was some filtering strategy suggested in a chatroom, not in the lit)

i refactored the top_k a bit, let me know how it looks!

stas-sl commented 1 year ago

Thanks, now it looks much more readable to me!