Modalities / modalities

Modalities, a PyTorch-native framework for distributed and reproducible foundation model training.
MIT License

Add topk sampling for generation #146

Open fromm-m opened 3 months ago

fromm-m commented 3 months ago

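We should add top-k sampling to generation so the model "stays on track": at each step, only the k most likely tokens are eligible, so a token from the low-probability tail can never be sampled.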

```python
# do top-k sampling of 50 (huggingface pipeline default)
# topk_probs here becomes (5, 50), topk_indices is (5, 50)
topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
# select a token from the top-k probabilities
# note: multinomial does not demand the input to sum to 1
ix = torch.multinomial(topk_probs, 1, generator=sample_rng)  # (B, 1)
```
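For reference, a minimal self-contained sketch of the same idea. The function name `sample_top_k`, the random logits, and the seed are illustrative assumptions, not part of the Modalities codebase. Note that `ix` in the snippet indexes into the top-k slice rather than the vocabulary, so the sketch adds a `torch.gather` step to map it back to vocabulary ids.

```python
import torch


def sample_top_k(logits: torch.Tensor, k: int = 50, generator=None) -> torch.Tensor:
    """Sample one token id per row from the k most likely tokens.

    logits: (B, vocab_size) unnormalized next-token scores.
    Returns a (B, 1) tensor of vocabulary ids.
    (Hypothetical helper for illustration only.)
    """
    k = min(k, logits.size(-1))  # guard against k > vocab_size
    probs = torch.softmax(logits, dim=-1)
    # keep only the k largest probabilities per row; both outputs are (B, k)
    topk_probs, topk_indices = torch.topk(probs, k, dim=-1)
    # multinomial does not require rows to sum to 1, so no renormalization is needed
    ix = torch.multinomial(topk_probs, 1, generator=generator)  # (B, 1), positions in [0, k)
    # map positions within the top-k slice back to vocabulary ids
    return torch.gather(topk_indices, -1, ix)  # (B, 1)


# usage with made-up logits: batch of 5, vocabulary of 1000
sample_rng = torch.Generator().manual_seed(42)
logits = torch.randn(5, 1000, generator=sample_rng)
next_tokens = sample_top_k(logits, k=50, generator=sample_rng)
print(next_tokens.shape)  # torch.Size([5, 1])
```

In a generation loop, `next_tokens` would be concatenated to the running sequence before the next forward pass. Exposing `k` as a parameter, rather than hard-coding 50, would make it configurable.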