cuda-mode / ring-attention

ring-attention experiments
Apache License 2.0

few more versions of sampling #13

Closed: melvinebenezer closed this 6 months ago

melvinebenezer commented 6 months ago

Sampling of logits in Ring Attention

melvinebenezer commented 6 months ago

Inputs

For an input of shape [1, 32] (dummy tokens), the input is split into 2 chunks, giving a [1, 16] chunk on each GPU.
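The comment does not show how the input is split. Below is a minimal pure-Python sketch, assuming a contiguous split along the sequence dimension (in PyTorch this corresponds roughly to `x.chunk(2, dim=1)` on a [1, 32] tensor; other layouts, such as striped partitioning, also exist). The helper `chunk_sequence` is hypothetical, and the 32000-token vocabulary is taken from the logged logit shapes.

```python
import random

def chunk_sequence(tokens, world_size):
    """Split one sequence of token ids into `world_size` contiguous chunks.

    Mirrors splitting a [1, seq_len] tensor along the sequence dimension.
    Assumes seq_len is divisible by world_size.
    """
    seq_len = len(tokens)
    if seq_len % world_size != 0:
        raise ValueError("sequence length must be divisible by world_size")
    chunk_len = seq_len // world_size
    return [tokens[r * chunk_len:(r + 1) * chunk_len] for r in range(world_size)]

# 32 dummy token ids drawn from a 32000-token vocabulary.
random.seed(0)
VOCAB_SIZE = 32000
dummy_tokens = [random.randrange(VOCAB_SIZE) for _ in range(32)]

chunks = chunk_sequence(dummy_tokens, world_size=2)
for rank, chunk in enumerate(chunks):
    print(f"rank={rank} x.shape=[1, {len(chunk)}]")
```

Each rank then runs the model on its own [1, 16] chunk, which matches the `rank=0` / `rank=1` shape lines in the logs.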

greedy

rank=0 x.shape=torch.Size([1, 16])
rank=1 x.shape=torch.Size([1, 16])
next_token_logits torch.Size([16, 32000])
next_token_logits torch.Size([16, 32000])
-- gather
sample_indices shape : torch.Size([16])
Next probable Sampled_Tokens : torch.Size([16])
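The greedy path reduces a [16, 32000] logits tensor to 16 token ids. A minimal pure-Python sketch of that reduction follows, assuming greedy means taking the highest-scoring vocabulary id at each position (in PyTorch, roughly `torch.argmax(next_token_logits, dim=-1)`). `greedy_sample` is a hypothetical name, and the random logits stand in for real model output.

```python
import random

def greedy_sample(logits):
    """Greedy decoding: pick the highest-scoring token id at every position.

    `logits` is a list of rows, one per sequence position, each of length
    vocab_size. Returns one token id per row; on ties the lowest id wins.
    """
    return [max(range(len(row)), key=row.__getitem__) for row in logits]

random.seed(0)
seq_chunk, vocab_size = 16, 32000
# Stand-in for one rank's next_token_logits of shape [16, 32000].
next_token_logits = [[random.gauss(0.0, 1.0) for _ in range(vocab_size)]
                     for _ in range(seq_chunk)]

sampled = greedy_sample(next_token_logits)
print(len(sampled))  # 16
```

Greedy selection is deterministic, so it is a convenient baseline when checking that the gathered logits line up across ranks.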

top_k = 5

rank=0 x.shape=torch.Size([1, 16])
rank=1 x.shape=torch.Size([1, 16])
next_token_logits torch.Size([16, 32000])
next_token_logits torch.Size([16, 32000])
-- gather
probabilities : torch.Size([16, 32000])
topk_vals: torch.Size([16, 5]), topk_indices : torch.Size([16, 5])
topk_vals after: torch.Size([16, 5])
Next probable Sampled_Tokens : torch.Size([16])
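The top-k log shows softmax probabilities of shape [16, 32000], then the 5 best values and indices per row, then one sampled token per row. A minimal pure-Python sketch of that pipeline follows, assuming "topk_vals after" refers to the kept probabilities after renormalisation (in PyTorch the steps map roughly onto `torch.softmax`, `torch.topk` and `torch.multinomial`). `top_k_sample` and `softmax` are hypothetical helpers, not code from the PR.

```python
import heapq
import math
import random

def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def top_k_sample(logits, k, rng):
    """Sample one token id per row from the k most probable tokens."""
    sampled = []
    for row in logits:
        probs = softmax(row)  # one row of the [16, 32000] probabilities
        # Indices of the k most probable tokens (one row of topk_indices).
        topk_indices = heapq.nlargest(k, range(len(probs)), key=probs.__getitem__)
        topk_vals = [probs[i] for i in topk_indices]
        # Renormalise so the k kept probabilities sum to 1.
        total = sum(topk_vals)
        topk_vals = [v / total for v in topk_vals]
        # Draw one candidate; population is already in vocabulary ids.
        sampled.append(rng.choices(topk_indices, weights=topk_vals)[0])
    return sampled

rng = random.Random(0)
seq_chunk, vocab_size = 16, 32000
next_token_logits = [[rng.gauss(0.0, 1.0) for _ in range(vocab_size)]
                     for _ in range(seq_chunk)]
tokens = top_k_sample(next_token_logits, k=5, rng=rng)
print(len(tokens))  # 16
```

`random.choices` would also accept the unnormalised weights; the explicit renormalisation is kept only to mirror the separate "topk_vals after" step in the log.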

top_p

rank=0 x.shape=torch.Size([1, 16])
rank=1 x.shape=torch.Size([1, 16])
next_token_logits torch.Size([16, 32000])
next_token_logits torch.Size([16, 32000])
-- gather
Next probable Sampled_Tokens : torch.Size([16])
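The top_p log only records the input and output shapes, and the value of p is not given. A minimal pure-Python sketch of standard nucleus (top-p) sampling follows, assuming a hypothetical p = 0.9: per row, keep the most probable tokens until their cumulative probability reaches p, then sample among them. `top_p_sample` and `softmax` are illustrative names, not code from the PR.

```python
import math
import random

def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def top_p_sample(logits, p, rng):
    """Nucleus sampling: one token id per row from the smallest set whose
    cumulative probability reaches p (at least one token is always kept)."""
    sampled = []
    for row in logits:
        probs = softmax(row)
        order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
        kept, cumulative = [], 0.0
        for idx in order:
            kept.append(idx)
            cumulative += probs[idx]
            if cumulative >= p:  # include the token that crosses the threshold
                break
        weights = [probs[i] for i in kept]  # choices() normalises the weights
        sampled.append(rng.choices(kept, weights=weights)[0])
    return sampled

rng = random.Random(0)
seq_chunk, vocab_size = 16, 32000
next_token_logits = [[rng.gauss(0.0, 1.0) for _ in range(vocab_size)]
                     for _ in range(seq_chunk)]
tokens = top_p_sample(next_token_logits, p=0.9, rng=rng)
print(len(tokens))  # 16
```

Unlike top-k, the number of candidates per row varies: a peaked row may keep a single token, while the near-uniform random logits here keep most of the vocabulary.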

melvinebenezer commented 6 months ago

Current understanding ... the sampled logits are not clear:
[Screenshot 2024-03-06 at 8:44:50 PM]

melvinebenezer commented 6 months ago

The sampling is done. I can move the logit generation to decoding.py to avoid conflicts with others' changes.