loss = -prob[torch.arange(32), Y].log().mean()
loss
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[20], line 1
----> 1 loss = -prob[torch.arange(32), Y].log().mean()
2 loss
IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [32], [228146]
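
The traceback shows that the two index tensors disagree in length: `torch.arange(32)` has 32 entries, while `Y` has 228146. A likely cause (an assumption, since the earlier cells are not shown here) is that `Y` now holds labels for the full dataset while the row index is still hardcoded to the earlier batch size. Below is a minimal sketch of one way to avoid the mismatch by deriving the row index from `Y` itself. The sizes `N` and `num_classes` and the random `logits` are hypothetical stand-ins, not values from the notebook, and `prob` must also have one row per entry of `Y` for this to work.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for the notebook's tensors (sizes chosen for illustration).
N, num_classes = 1000, 27
logits = torch.randn(N, num_classes)
Y = torch.randint(0, num_classes, (N,))
prob = torch.softmax(logits, dim=1)  # one row of probabilities per example

# Build the row index from Y so both index tensors always have the same length.
rows = torch.arange(Y.shape[0])
loss = -prob[rows, Y].log().mean()

# The same negative log-likelihood computed directly from the logits.
loss_ce = torch.nn.functional.cross_entropy(logits, Y)
print(loss.item(), loss_ce.item())
```

Using `Y.shape[0]` instead of a literal keeps the indexing correct if the dataset size changes again; `F.cross_entropy` sidesteps the manual indexing entirely.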