Lightning-AI / litgpt

Pretrain, finetune, deploy 20+ LLMs on your own data. Uses state-of-the-art techniques: flash attention, FSDP, 4-bit, LoRA, and more.
https://lightning.ai
Apache License 2.0

Nucleus (top-p) sampling #1347

Open belerico opened 3 weeks ago

belerico commented 3 weeks ago

Nucleus sampling (top-p sampling in HF) is a dynamic sampling strategy that "truncat[es] the unreliable tail of the probability distribution, sampling from the dynamic nucleus of tokens containing the vast majority of the probability mass." It can be easily implemented in the `sample` method like this:

from typing import Optional

import torch


def sample(
    logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[float] = None
) -> torch.Tensor:
    logits = logits[0, -1]
    # optionally crop the logits to only the top k options
    if top_k is not None:
        v, i = torch.topk(logits, min(top_k, logits.size(-1)))
        # do not use `torch.where` as in nanogpt because it will repeat top-k collisions
        logits = torch.full_like(logits, float("-inf")).scatter_(-1, i, v)
    # optionally crop the logits to smallest set of logits with a cumulative probability above top_p
    if top_p is not None:
        sorted_logits, sorted_indices = torch.sort(logits, descending=False)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
        indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
        logits = logits.masked_fill(indices_to_remove, float("-inf"))
    # optionally scale the logits and sample from a probability distribution
    if temperature > 0.0:
        probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
        return multinomial_num_samples_1(probs)  # litgpt's existing helper: draws one token index from `probs`
    return torch.argmax(logits, dim=-1, keepdim=True)
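
For a quick sanity check, here is a minimal, self-contained sketch of how the function above could be exercised on dummy logits. The vocab size, the sampling parameters, and the stand-in for `multinomial_num_samples_1` are illustrative assumptions, not part of the proposal itself:

import torch


def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
    # stand-in for litgpt's helper: draw a single token index from `probs`
    return torch.multinomial(probs, num_samples=1)


torch.manual_seed(0)
dummy_logits = torch.randn(1, 1, 50)  # (batch=1, seq_len=1, vocab_size=50)
token = sample(dummy_logits, temperature=0.8, top_k=40, top_p=0.9)
print(token.shape)  # torch.Size([1]): a single sampled token id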

I can open a PR with this addition if it's considered useful.

rasbt commented 3 weeks ago

Thanks for suggesting this and for offering to contribute!

In short, instead of keeping a fixed number of tokens as in top-k, it keeps the smallest set of most-probable tokens whose cumulative probability exceeds a threshold p. I think this is a popular, standard technique and could potentially be added as an option for litgpt chat, analogous to and in addition to the top_k setting. It would be a nice contribution. What do you think @awaelchli @carmocca ?
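
For intuition, a tiny sketch with made-up probabilities, mirroring the masking rule in the snippet above (keep tokens once the ascending cumulative probability exceeds 1 - top_p):

import torch

probs = torch.tensor([0.05, 0.05, 0.10, 0.30, 0.50])  # toy distribution, sorted ascending
cumprobs = probs.cumsum(dim=-1)                        # [0.05, 0.10, 0.20, 0.50, 1.00]
top_p = 0.7
keep = cumprobs > (1 - top_p)                          # [False, False, False, True, True]
# only the two most probable tokens (0.30 and 0.50) form the nucleus: together they
# cover 0.80 >= top_p; unlike top-k, the size of this set adapts to how peaked the
# distribution is
print(keep)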

carmocca commented 3 weeks ago

I agree