EleutherAI / gpt-neox

An implementation of model parallel autoregressive transformers on GPUs, based on the Megatron and DeepSpeed libraries
https://www.eleuther.ai/
Apache License 2.0

For nucleus sampling, top-p sampling appears to happen on the softmax-normalized top-k logits #1250

Closed · j-frei closed this 2 months ago

j-frei commented 4 months ago

Describe the bug
This issue refers to the code line at: https://github.com/EleutherAI/gpt-neox/blob/1cee5b7c7074302de4867ad5cac3f1ea26f7a7d7/megatron/text_generation_utils.py#L100C43-L100C50

To my understanding, top-p should be applied to the pre-top-k-filtered token probabilities. However, when both top-k and top-p are enabled, the top-p step is currently applied to the post-top-k-filtered logits, because an additional softmax is applied here to the already-masked logit values.

Expected behavior
Given a large top-p value and a very small top-k value (with k > 1), the top-p step should have no effect.
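
To make this concrete, here is a minimal numeric sketch (toy values, not taken from the repo) contrasting the two orderings for top_k=2 and top_p=0.95:

import torch
import torch.nn.functional as F

probs = torch.tensor([0.80, 0.04] + [0.016] * 10)  # 12-token toy vocab, sums to 1.0
logits = probs.log()
top_k, top_p = 2, 0.95

# Cascade (as I read the linked code): softmax over the top-k survivors, then top-p on that.
renorm = F.softmax(torch.topk(logits, top_k)[0], dim=-1)  # ~[0.9524, 0.0476]
print(renorm.cumsum(-1))  # 0.9524 > 0.95, so the second token gets dropped by top-p

# Expected behavior above: top-p computed on the original distribution.
print(probs.cumsum(-1)[:2])  # 0.80 and 0.84 are both < 0.95, so the second token survives top-p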

If, contrary to my intuition, the current implementation does match your intended behavior for nucleus sampling, feel free to ignore this issue.

j-frei commented 4 months ago

To my understanding, the function should instead use the original input logits for the top-p step when determining which tokens to mask, e.g.:

import torch
import torch.nn.functional as F


def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
    """
    Filters the logits using top_k / top_p, filling any filtered vocab items with filter_value (defaults to -inf).

    This function has been mostly taken from huggingface conversational ai code at
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313

    logits: torch.Tensor -> logits of the Megatron model.
    top_k: integer -> integer between 0 and the model's vocab size. Filters out any logits with a probability less than that of the k-th most likely token.
    top_p: float -> Top-p (nucleus) sampling chooses from the smallest possible set of tokens whose cumulative probability exceeds the probability top_p.

    returns: (filtered) logits"""

    masked_logits = logits.clone()
    if top_k > 0:
        # Remove all tokens with a probability less than the
        # last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        masked_logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # sort the original (unfiltered) logits in descending order
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token
        # above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
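        # map the removal mask from sorted order back to the original token order, row by row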
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            masked_logits[i][indices_to_remove] = filter_value

    return masked_logits
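
For reference, a hypothetical usage sketch of the proposed function (toy logits; assumes the imports and the filter_logits definition above):

logits = torch.tensor([[2.0, 1.5, 0.5, -1.0, -2.0]])  # batch of 1, vocab of 5
filtered = filter_logits(logits, top_k=3, top_p=0.9)
# tokens outside the top-3, and tokens past the 0.9 nucleus of the *original*
# distribution, are now -inf; sampling then proceeds as usual
probs = F.softmax(filtered, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)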
AI-WAIFU commented 2 months ago

This is intentional. We're mirroring the functionality documented in https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313, and it is also in line with the default behavior of implementations such as llama.cpp (https://github.com/ggerganov/llama.cpp/blob/38ca6f644bd48301e9caa80f9913c22e70a8fd1b/examples/server/README.md?plain=1#L370), which cascade the token filters one after another, with top_k applied before top_p.
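
For readers skimming the thread, here's a rough sketch of that cascaded default (toy code, not the repo's actual implementation; assumes top_k >= 1 and 0 < top_p < 1):

import torch
import torch.nn.functional as F

def cascaded_filter(logits, top_k, top_p, filter_value=-float("Inf")):
    # 1) top-k: mask everything below the k-th largest logit
    kth = torch.topk(logits, top_k)[0][..., -1, None]
    logits = logits.masked_fill(logits < kth, filter_value)
    # 2) top-p: computed on the softmax of the already-masked logits,
    #    i.e. on the re-normalized top-k distribution
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    cumprobs = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    remove = cumprobs > top_p
    remove[..., 1:] = remove[..., :-1].clone()  # keep the first token above the threshold
    remove[..., 0] = False
    mask = remove.scatter(-1, sorted_idx, remove)  # back to original token order
    return logits.masked_fill(mask, filter_value)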

There's definitely room to make sampling more flexible than it currently is, so we're open to PRs and feature suggestions/requests on how to go about that, but for the default we believe this is the correct behavior.

Quentin-Anthony commented 2 months ago

Closing for now. Feel free to reopen if you'd like to discuss further!