Closed j-frei closed 2 months ago
To my understanding, the function should instead use the original input logits for the top_p step when determining the masked tokens:
import torch
import torch.nn.functional as F

def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
    """
    Filters the logits using top_k / top_p, filling any filtered vocab items with filter_value (defaults to -inf).
    This function has been mostly taken from huggingface conversational ai code at
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
    logits: torch.Tensor -> logits of megatron model.
    top_k: integer -> integer between 0 and the model's vocab size. Filters out any logits with a probability less than that of the top_kth token.
    top_p: float -> Top-p (nucleus) sampling chooses from the smallest possible set of tokens whose cumulative probability exceeds the probability top_p.
    returns: (filtered) logits"""
    masked_logits = logits.clone()
    if top_k > 0:
        # Remove all tokens with a logit below that of the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        masked_logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        # Sort the *original* logits so the nucleus is computed on the unfiltered distribution
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Map the removal mask from sorted order back to the original vocab order, per batch row
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            masked_logits[i][indices_to_remove] = filter_value
    return masked_logits
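For reference, this is how the proposed variant would be called during sampling (a minimal sketch; the example logits and the sampling step are illustrative, not gpt-neox code):

import torch
import torch.nn.functional as F

# Hypothetical batch of 2 sequences over a 5-token vocab
logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0],
                       [0.3, 0.2, 0.1, 0.0, -0.5]])

filtered = filter_logits(logits, top_k=3, top_p=0.9)

# Filtered positions are -inf, so softmax assigns them zero probability
probs = F.softmax(filtered, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1)  # shape: (2, 1)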
This is intentional: we're mirroring the functionality documented in https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313. This is also in line with the default behavior of implementations such as llama.cpp (https://github.com/ggerganov/llama.cpp/blob/38ca6f644bd48301e9caa80f9913c22e70a8fd1b/examples/server/README.md?plain=1#L370), which cascade token filters one after another, with top_k coming before top_p.
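For comparison, the cascaded order described above looks roughly like the following (a simplified sketch, not the actual gpt-neox code): top_k is applied first, and top_p then operates on a softmax over the already-masked logits, so the surviving top-k tokens are renormalized before the nucleus cutoff.

import torch
import torch.nn.functional as F

def cascaded_filter(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
    # Sketch of the cascaded order: top_k first, then top_p computed from a softmax
    # over the already-masked logits (the top-k survivors are renormalized).
    logits = logits.clone()
    if top_k > 0:
        kth_logit = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_logit] = filter_value
    if top_p > 0.0:
        # Note: sorting the masked logits, not the original ones
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token above the threshold is kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value
    return logits

The only difference from the snippet above is which tensor the sort and softmax for the nucleus are taken from.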
There's definitely a lot of room to add sampling flexibility that isn't currently implemented, so we're open to PRs and feature suggestions/requests for how to go about it, but at least for the default we believe this is the correct behavior.
Closing for now. Feel free to reopen if you'd like to discuss further!
Describe the bug
This issue refers to the code line at: https://github.com/EleutherAI/gpt-neox/blob/1cee5b7c7074302de4867ad5cac3f1ea26f7a7d7/megatron/text_generation_utils.py#L100C43-L100C50
To my understanding, top-p should be applied to the pre-top-k-filtered token probabilities. Apparently, though, if both top-k and top-p are enabled, the top-p part is applied to the post-top-k-filtered logits, since an additional softmax is computed here on the updated logit values.
Expected behavior
Given a large top-p value and a very small top-k value (k > 1), the top-p part should have no effect.
If, contrary to my intuition, the current implementation indeed matches your intended behavior for nucleus sampling, you can ignore this issue.
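To make the discrepancy concrete, here is a small self-contained example (the logits are made up): with top_k = 2 and top_p = 0.95, the nucleus computed on the original distribution never cuts into the top-2 set, while the nucleus computed on the renormalized post-top-k distribution drops the second token.

import torch
import torch.nn.functional as F

# Hypothetical logits over a 5-token vocab; one token clearly dominates
logits = torch.tensor([[4.0, 1.0, 0.9, 0.8, 0.7]])
top_k, top_p = 2, 0.95

# Nucleus on the original distribution: the top-2 tokens cover < 95% of the mass,
# so a 0.95 cutoff only removes tokens that top_k already removed.
probs = F.softmax(logits, dim=-1)
print(probs.cumsum(-1))  # ~[0.853, 0.895, 0.934, 0.969, 1.0]

# Nucleus after top_k renormalization: the best token alone now covers > 95%,
# so the cascaded order additionally drops the second token.
topk_logits = torch.topk(logits, top_k)[0]
print(F.softmax(topk_logits, dim=-1).cumsum(-1))  # ~[0.953, 1.0]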