Closed melvinebenezer closed 6 months ago
For an input of [1, 32] dummy tokens, the input is chunked across 2 GPUs, giving a [1, 16] shard per GPU:
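As a minimal single-process sketch of the chunking described above (the variable names here are illustrative, not from the actual code), splitting a [1, 32] token batch across 2 ranks yields one [1, 16] shard per rank:

```python
import numpy as np

world_size = 2
tokens = np.arange(32).reshape(1, 32)          # dummy input, shape [1, 32]

# Split along the sequence dimension: two [1, 16] shards, one per GPU.
chunks = np.split(tokens, world_size, axis=1)

for rank, x in enumerate(chunks):
    print(f"rank={rank} x.shape={x.shape}")    # each rank sees [1, 16]
```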
rank=0 x.shape=torch.Size([1, 16])
rank=1 x.shape=torch.Size([1, 16])
next_token_logits torch.Size([16, 32000])
next_token_logits torch.Size([16, 32000])
-- gather
sample_indices shape : torch.Size([16])
Next probable Sampled_Tokens : torch.Size([16])
rank=0 x.shape=torch.Size([1, 16])
rank=1 x.shape=torch.Size([1, 16])
next_token_logits torch.Size([16, 32000])
next_token_logits torch.Size([16, 32000])
-- gather
probabilities : torch.Size([16, 32000])
topk_vals: torch.Size([16, 5]), topk_indices : torch.Size([16, 5])
topk_vals after: torch.Size([16, 5])
Next probable Sampled_Tokens : torch.Size([16])
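The shapes in that run ([16, 32000] probabilities, [16, 5] top-k values/indices, [16] sampled tokens) are consistent with per-position top-k sampling over the vocabulary. A hedged numpy sketch of that path, with synthetic logits standing in for the gathered `next_token_logits`:

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, vocab_size, k = 16, 32000, 5

# Synthetic stand-in for the gathered logits, shape [16, 32000].
next_token_logits = rng.normal(size=(seq_len, vocab_size))

# Softmax over the vocabulary dimension.
z = next_token_logits - next_token_logits.max(axis=-1, keepdims=True)
probabilities = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)

# Top-k filter: keep the k most probable tokens per position.
topk_indices = np.argsort(probabilities, axis=-1)[:, -k:]        # [16, 5]
topk_vals = np.take_along_axis(probabilities, topk_indices, -1)  # [16, 5]
topk_vals = topk_vals / topk_vals.sum(axis=-1, keepdims=True)    # renormalize

# Sample one token per position from the renormalized top-k distribution.
sampled_tokens = np.array([
    rng.choice(topk_indices[i], p=topk_vals[i]) for i in range(seq_len)
])  # shape (16,), matching "Next probable Sampled_Tokens : torch.Size([16])"
```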
rank=0 x.shape=torch.Size([1, 16])
rank=1 x.shape=torch.Size([1, 16])
next_token_logits torch.Size([16, 32000])
next_token_logits torch.Size([16, 32000])
-- gather
Next probable Sampled_Tokens : torch.Size([16])
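In this run each rank produces [16, 32000] logits for its shard and ends with [16] sampled tokens, which a gather (e.g. `torch.distributed.all_gather`) would then combine into tokens for the full [1, 32] input. A single-process sketch simulating that with synthetic per-rank logits and greedy argmax standing in for the sampling step:

```python
import numpy as np

rng = np.random.default_rng(1)
seq_len, vocab_size = 16, 32000

# Simulate the two ranks' logit shards from the logs, each [16, 32000].
per_rank_logits = [rng.normal(size=(seq_len, vocab_size)) for _ in range(2)]

# Greedy stand-in for sampling: argmax over the vocab, one id per position.
per_rank_tokens = [logits.argmax(axis=-1) for logits in per_rank_logits]

# A real gather would collect the shards across ranks; here we concatenate
# to recover one token per position of the original [1, 32] input.
all_tokens = np.concatenate(per_rank_tokens)   # shape (32,)
```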
Current understanding: ... how the sampled logits should be handled across ranks is still not clear.
The sampling is done; the logit generation can be moved to decoding.py to avoid interfering with others' changes.
Sampling of logits in Ring Attention