karpathy / ng-video-lecture


"index out of range" error when using a different embedding dimension than vocab_size #20

Open zhoupingjay opened 1 year ago

zhoupingjay commented 1 year ago

https://github.com/karpathy/ng-video-lecture/blob/52201428ed7b46804849dea0b3ccf0de9df1a5c3/bigram.py#L66

If I change the 2nd parameter of nn.Embedding (the embedding dimension) to a value other than vocab_size (e.g. 128), I get an "index out of range" error in generate().

To replicate the error, just change this line in the notebook:

class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, 128)    # <-- change dimension to 128

And then rerun the cell:

torch.Size([32, 128])
tensor(5.2106, grad_fn=<NllLossBackward0>)
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-14-58747f1080e0> in <cell line: 48>()
     46 print(loss)
     47 
---> 48 print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))

<ipython-input-14-58747f1080e0> in generate(self, idx, max_new_tokens)
     30         for _ in range(max_new_tokens):
     31             # get the predictions
---> 32             logits, loss = self(idx)
     33             # focus only on the last time step
     34             logits = logits[:, -1, :] # becomes (B, C)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-14-58747f1080e0> in forward(self, idx, targets)
     14 
     15         # idx and targets are both (B,T) tensor of integers
---> 16         logits = self.token_embedding_table(idx) # (B,T,C)
     17 
     18         if targets is None:

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/sparse.py in forward(self, input)
    160 
    161     def forward(self, input: Tensor) -> Tensor:
--> 162         return F.embedding(
    163             input, self.weight, self.padding_idx, self.max_norm,
    164             self.norm_type, self.scale_grad_by_freq, self.sparse)

/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2208         # remove once script supports set_grad_enabled
   2209         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2210     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
   2211 
   2212 

IndexError: index out of range in self
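For reproducing this outside the notebook, here is a minimal standalone sketch. It assumes a hypothetical vocab_size of 65 and replaces the model with a bare nn.Embedding whose output is used directly as logits, mirroring the sampling loop in generate(); the helper name sample_tokens is made up for illustration and is not part of the repo.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)
vocab_size = 65  # hypothetical vocabulary size, assumed for illustration


def sample_tokens(table, idx, max_new_tokens):
    """Hypothetical helper mirroring generate(): embedding output used directly as logits."""
    for _ in range(max_new_tokens):
        logits = table(idx)[:, -1, :]          # (B, C), where C = embedding dimension
        probs = F.softmax(logits, dim=-1)      # distribution over C entries, not vocab_size
        idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1), values in [0, C)
        idx = torch.cat((idx, idx_next), dim=1)  # looked up as token ids on the next step
    return idx


start = torch.zeros((1, 1), dtype=torch.long)

# Embedding dimension == vocab_size: every sampled id is a valid row index
ok_out = sample_tokens(nn.Embedding(vocab_size, vocab_size), start, max_new_tokens=100)
print(ok_out.shape, ok_out.max().item() < vocab_size)

# Embedding dimension 128 > vocab_size: ids >= vocab_size get sampled and fed back
raised = False
try:
    sample_tokens(nn.Embedding(vocab_size, 128), start, max_new_tokens=100)
except IndexError as err:
    raised = True
    print("IndexError:", err)
```

With 128 output columns, roughly half of each step's probability mass typically falls on ids 65..127, so an out-of-range id is almost certain to be sampled within a few steps and to fail on the following lookup.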
zhoupingjay commented 1 year ago

More analysis: I think the root cause is here:

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

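We use the output of the Embedding directly as logits, which means each dimension of the embedding is treated as the unnormalized score (logit) of one class, i.e. one token. This effectively requires the embedding dimension to equal the number of classes (vocab_size). If the embedding dimension is larger than vocab_size (e.g. 128), softmax and torch.multinomial sample over all 128 positions, so the next token index (idx_next) can be vocab_size or larger. When that index is appended to idx and passed through token_embedding_table on the next iteration, the lookup goes past the last row of the table and raises "IndexError: index out of range in self".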
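One common way to decouple the embedding size from vocab_size is to add a linear projection from the embedding dimension back to vocab_size, the same pattern the lecture's later gpt.py uses with its lm_head layer. The sketch below is a hypothetical BigramWithProjection class (not the repo's bigram.py), assuming n_embd=128 and a vocab_size of 65 for illustration.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F


class BigramWithProjection(nn.Module):
    """Hypothetical variant: embed into n_embd dims, then project back to vocab_size."""

    def __init__(self, vocab_size, n_embd=128):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)  # (B, T, n_embd) -> (B, T, vocab_size)

    def forward(self, idx, targets=None):
        tok_emb = self.token_embedding_table(idx)  # (B, T, n_embd)
        logits = self.lm_head(tok_emb)             # (B, T, vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, _ = self(idx)
            logits = logits[:, -1, :]              # (B, vocab_size)
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)  # values in [0, vocab_size)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx


torch.manual_seed(1337)
vocab_size = 65  # hypothetical vocabulary size
m = BigramWithProjection(vocab_size, n_embd=128)
xb = torch.randint(0, vocab_size, (4, 8))
yb = torch.randint(0, vocab_size, (4, 8))
logits, loss = m(xb, yb)
out = m.generate(torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)
print(logits.shape, loss.item(), out.shape)
```

The model still only conditions on the last token, so it remains a bigram model; the difference is that the sampled distribution always has exactly vocab_size entries, whatever n_embd is.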