Lightning-AI / litgpt

20+ high-performance LLMs with recipes to pretrain, finetune and deploy at scale.
https://lightning.ai
Apache License 2.0

Nucleus (top-p) sampling #1347

Open belerico opened 7 months ago

belerico commented 7 months ago

Nucleus sampling (top-p sampling in HF) is a dynamic sampling strategy that "truncat[es] the unreliable tail of the probability distribution, sampling from the dynamic nucleus of tokens containing the vast majority of the probability mass." It can easily be implemented in the `sample` method like this:

from typing import Optional

import torch


def sample(
    logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[float] = None
) -> torch.Tensor:
    # logits has shape (batch, sequence, vocab); keep only the last position of the first sequence
    logits = logits[0, -1]
    # optionally crop the logits to only the top k options
    if top_k is not None:
        v, i = torch.topk(logits, min(top_k, logits.size(-1)))
        # do not use `torch.where` as in nanogpt because it will repeat top-k collisions
        logits = torch.full_like(logits, float("-inf")).scatter_(-1, i, v)
    # optionally crop the logits to the smallest set of logits with a cumulative probability above top_p
    if top_p is not None:
        # sort ascending so the cumulative sum accumulates the tail of the distribution first
        sorted_logits, sorted_indices = torch.sort(logits, descending=False)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        # drop tokens whose cumulative probability (counted from the tail) stays below 1 - top_p
        sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
        indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
        logits = logits.masked_fill(indices_to_remove, float("-inf"))
    # optionally scale the logits and sample from a probability distribution
    if temperature > 0.0:
        probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
        return multinomial_num_samples_1(probs)  # litgpt's existing single-sample multinomial helper
    return torch.argmax(logits, dim=-1, keepdim=True)
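
For reference, a minimal sketch of how the modified function could be exercised; the `multinomial_num_samples_1` stand-in below mirrors the helper that already exists in litgpt, and the shapes and values are only illustrative:

def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
    # stand-in for litgpt's helper: draw a single token index from the distribution
    return torch.multinomial(probs, num_samples=1)

logits = torch.randn(1, 8, 100)  # (batch, sequence, vocab_size) as returned by the model
next_token = sample(logits, temperature=0.8, top_k=50, top_p=0.9)
print(next_token)  # a single sampled token index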

I can open a PR with this addition if it is considered useful.

rasbt commented 7 months ago

Thanks for the suggestion and for offering to contribute!

In short, instead of keeping a fixed number of candidate tokens as in top-k, it keeps the smallest set of tokens whose cumulative probability reaches a threshold p. I think this is a popular, standard technique and could be added as an option to litgpt chat, analogous and in addition to the top_k setting. It would be a nice contribution. What do you think @awaelchli @carmocca ?
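
To make that concrete, here is a small illustrative sketch (hypothetical probabilities, independent of the litgpt code): with the same top_p, a peaked distribution keeps very few tokens while a flat one keeps many, whereas top-k always keeps exactly k.

import torch

def nucleus_set_size(probs: torch.Tensor, top_p: float) -> int:
    # smallest number of highest-probability tokens whose cumulative probability reaches top_p
    sorted_probs, _ = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    return int((cumulative < top_p).sum().item()) + 1

peaked = torch.tensor([0.80, 0.10, 0.05, 0.03, 0.02])  # confident prediction
flat = torch.tensor([0.25, 0.22, 0.20, 0.18, 0.15])    # uncertain prediction

print(nucleus_set_size(peaked, top_p=0.8))  # 1 token already covers 80% of the mass
print(nucleus_set_size(flat, top_p=0.8))    # 4 tokens are needed to reach 80%
# top-k with k=3 would keep exactly 3 tokens in both cases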

carmocca commented 7 months ago

I agree