we should add top-k sampling so the model "stays on track": instead of sampling from the full vocabulary, we only ever sample from the k most probable tokens, which clamps out the long tail of unlikely tokens
`
# do top-k sampling of 50 (huggingface pipeline default)
# topk_probs here becomes (5, 50), topk_indices is (5, 50)
topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
# select a token from the top-k probabilities
# note: multinomial does not demand the input to sum to 1
ix = torch.multinomial(topk_probs, 1, generator=sample_rng) # (B, 1)
# gather the corresponding token ids back from the top-k indices
xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
`
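To see the whole mechanic in isolation, here is a minimal self-contained sketch of top-k sampling: the logits are random stand-ins (in the real loop they come from the model at the last position), and `xcol` ends up holding one sampled token id per row, guaranteed to be among that row's 50 most likely tokens.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(42)
# stand-in logits for a batch of 5 sequences over a toy vocab of 100 tokens
logits = torch.randn(5, 100)
probs = F.softmax(logits, dim=-1)  # (5, 100), rows sum to 1

# keep only the 50 most likely tokens per row
topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)  # both (5, 50)

# multinomial renormalizes internally, so the truncated rows
# need not sum to 1
ix = torch.multinomial(topk_probs, 1)  # (5, 1), positions within the top-k

# map positions within the top-k back to actual vocabulary token ids
xcol = torch.gather(topk_indices, -1, ix)  # (5, 1)

# every sampled token is one of that row's top-50 tokens
for b in range(5):
    assert xcol[b, 0] in topk_indices[b]
```

In the generation loop, `xcol` would then be concatenated onto the running sequence (e.g. `x = torch.cat((x, xcol), dim=1)`) before the next forward pass.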