karpathy / ng-video-lecture


How is torch broadcasting (T, T) @ (B, T, C) ?! #34

Open whydna opened 1 year ago

whydna commented 1 year ago

At around 53:10 of the lecture, Andrej does a matrix multiplication with tensors of size (T, T) and (B, T, C). More precisely: (8, 8) @ (4, 8, 2).

Now, even after looking over PyTorch docs on broadcasting semantics, I'm surprised to see that this works - but sure enough, running the code produces an output of (4, 8, 2).

Can anyone explain how this broadcast works?

// align trailing dimensions
     8, 8
4, 8, 2

// pad missing dimensions with 1
1, 8, 8
4, 8, 2

// expand size-1 dimensions until they match
4, 8, 8
4, 8, 2

// now what???
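The "now what???" step can be sketched as a shape calculation. This is a minimal sketch, assuming both inputs have at least 2 dimensions (1-D operands, which `torch.matmul` also accepts, are not handled here); `broadcast_batch` and `matmul_shape` are hypothetical helper names, not PyTorch APIs. The idea it encodes is that only the leading batch dimensions go through the usual broadcasting rules, while the last two dimensions of each operand are treated as a matrix and must satisfy the ordinary (n, k) @ (k, m) rule.

```python
from itertools import zip_longest


def broadcast_batch(a, b):
    """Broadcast two batch shapes: right-aligned, missing dims act as size 1."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"batch dims {a} and {b} are not broadcastable")
        # A size-1 dim stretches to the other size; otherwise the sizes are equal.
        out.append(y if x == 1 else x)
    return tuple(reversed(out))


def matmul_shape(a, b):
    """Result shape of A @ B for inputs with >= 2 dims (hypothetical helper)."""
    if len(a) < 2 or len(b) < 2:
        raise ValueError("this sketch only handles inputs with >= 2 dims")
    *batch_a, n, k = a  # leading dims are batch dims, last two are the matrix
    *batch_b, k2, m = b
    if k != k2:
        raise ValueError(f"inner dims differ: {k} vs {k2}")
    return broadcast_batch(tuple(batch_a), tuple(batch_b)) + (n, m)


print(matmul_shape((8, 8), (4, 8, 2)))  # (4, 8, 2)
```

Here the (8, 8) operand has an empty batch shape, which broadcasts against (4,), and the matrix part (8, 8) @ (8, 2) gives (8, 2), so the result is (4, 8, 2).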
remorses commented 1 year ago

You can think of matrix multiplication as working on the last 2 dimensions, with the first dimension used only for batching.

You can basically ignore the first dimension: PyTorch does the matrix multiplication for each slice along it (each (T, C) matrix) and stacks the results at the end.
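One way to check this "batching" reading is to compare the broadcast product against an explicit loop over the first dimension. This sketch assumes PyTorch is installed; the lower-triangular averaging matrix is similar to the weights built around that point in the lecture, but any (T, T) tensor would work.

```python
import torch

torch.manual_seed(0)
B, T, C = 4, 8, 2
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)  # (T, T) causal averaging weights
x = torch.randn(B, T, C)              # (B, T, C)

out = wei @ x                         # broadcast batched matmul -> (B, T, C)

# The same computation "by hand": multiply the shared (T, T) matrix with
# each (T, C) slice of x, then stack the B results along a new first dim.
manual = torch.stack([wei @ x[b] for b in range(B)])

print(out.shape)                      # torch.Size([4, 8, 2])
print(torch.allclose(out, manual))    # True
```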

remorses commented 1 year ago

I think the same reasoning also applies to images. Tensors usually have shape [Batch, Channels, Height, Width] (NCHW), and you can treat the color channels as a group of separate images, i.e. as another batching dimension.
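With a 4-D NCHW tensor, both N and C act as batch dimensions and the matrix product runs over each (H, W) plane. A minimal sketch, assuming PyTorch; `mix_rows` is a hypothetical (H, H) matrix used only to illustrate the shapes.

```python
import torch

N, C, H, W = 2, 3, 5, 7
images = torch.randn(N, C, H, W)  # NCHW batch of images
mix_rows = torch.randn(H, H)      # one (H, H) matrix shared by every image/channel

out = mix_rows @ images           # (H, H) @ (N, C, H, W) -> (N, C, H, W)
print(out.shape)                  # torch.Size([2, 3, 5, 7])

# Each (H, W) plane is multiplied independently, e.g. image 1, channel 2:
assert torch.allclose(out[1, 2], mix_rows @ images[1, 2])
```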

whydna commented 1 year ago

@remorses thanks for the answer. Does this follow the standard broadcasting rules, or is it a special case? I can't seem to find it in the docs.

remorses commented 1 year ago

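It’s using the usual broadcasting rules, applied only to the batch dimensions (everything except the last two). If you mean whether it follows the matrix multiplication rule of having one dimension in common, then yes: the last dimension of the first tensor must match the second-from-right dimension of the second tensor.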
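To make the "usual broadcasting rules" part concrete, this sketch broadcasts only the batch dimensions by hand with `torch.broadcast_shapes`, expands both operands to that batch shape, and then calls `torch.bmm`, which does one plain matrix product per batch element. It assumes PyTorch and a single batch dimension after broadcasting (which is what `bmm` requires); it illustrates the semantics and is not a claim about how `torch.matmul` is implemented internally.

```python
import torch

wei = torch.randn(8, 8)     # (T, T)
x = torch.randn(4, 8, 2)    # (B, T, C)

# 1. Broadcast only the batch dims (all but the last two): () and (4,) -> (4,)
batch = torch.broadcast_shapes(wei.shape[:-2], x.shape[:-2])

# 2. Expand each operand to that batch shape; expand() returns a view, no copy
wei_b = wei.expand(*batch, *wei.shape[-2:])  # (4, 8, 8)
x_b = x.expand(*batch, *x.shape[-2:])        # (4, 8, 2), unchanged here

# 3. Matrix rule: last dim of the first operand == second-from-right dim of the second
assert wei_b.shape[-1] == x_b.shape[-2]

# 4. One (8, 8) @ (8, 2) product per batch element
out = torch.bmm(wei_b, x_b)

print(out.shape)                     # torch.Size([4, 8, 2])
print(torch.allclose(out, wei @ x))  # True
```

After step 2 the shapes are exactly the (4, 8, 8) and (4, 8, 2) from the question; the answer to "now what???" is simply a batched matrix multiply over the first dimension.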