srush / Triton-Puzzles

Puzzles for learning Triton
Apache License 2.0

Index Mixup #8

Closed: SamuelGabriel closed this issue 4 months ago.

SamuelGabriel commented 4 months ago

I think you mixed up the indices in questions 3 and 4. Your code actually computes

$$z_{j,i} = x_i + y_j\text{ for } i = 1\ldots B_0,\ j = 1\ldots B_1$$

instead of

$$z_{i, j} = x_i + y_j\text{ for } i = 1\ldots B_0,\ j = 1\ldots B_1$$

That is, I think i and j are swapped in the indices of z.
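The two formulas can be written out as explicit loops to make the difference concrete. This is a minimal pure-Python sketch; the helper names `outer_add_ij`, `outer_add_ji` and `transpose` are hypothetical and not part of the puzzles. It shows that the mixed-up version is the transpose of the intended one, so with non-square inputs ($B_0 \neq B_1$) even the output shape comes out as $B_1 \times B_0$ instead of $B_0 \times B_1$.

```python
def outer_add_ij(x, y):
    # Intended definition: z[i][j] = x[i] + y[j], shape (len(x), len(y)).
    return [[x[i] + y[j] for j in range(len(y))] for i in range(len(x))]


def outer_add_ji(x, y):
    # Mixed-up definition from the issue: z[j][i] = x[i] + y[j],
    # shape (len(y), len(x)).
    return [[x[i] + y[j] for i in range(len(x))] for j in range(len(y))]


def transpose(m):
    # Swap rows and columns of a list-of-lists matrix.
    return [list(row) for row in zip(*m)]


# Non-square inputs (B0 = 3, B1 = 4) so the shape difference is visible.
x, y = [1, 2, 3], [10, 20, 30, 40]
assert outer_add_ji(x, y) == transpose(outer_add_ij(x, y))
print(len(outer_add_ij(x, y)), len(outer_add_ij(x, y)[0]))  # 3 4
print(len(outer_add_ji(x, y)), len(outer_add_ji(x, y)[0]))  # 4 3
```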

You can see it by computing, e.g.,

```python
add_vec_spec(torch.tensor([1,2,3]), torch.tensor([10,20,30]))
```

which returns:

```
tensor([[11, 12, 13],
        [21, 22, 23],
        [31, 32, 33]])
```
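A swap like this usually comes down to which axis each vector is broadcast along. The following is a minimal PyTorch sketch, assuming the spec builds z by broadcasting; the puzzle's actual code may differ. The first expression reproduces the output above (the $z_{j,i}$ layout), and swapping where the new axis is inserted gives the intended $z_{i,j}$.

```python
import torch

x = torch.tensor([1, 2, 3])
y = torch.tensor([10, 20, 30])

# x varies along columns, y along rows: entry [j, i] holds x[i] + y[j].
# This matches the reported output (z_{j,i} layout).
z_ji = x[None, :] + y[:, None]

# x varies along rows, y along columns: entry [i, j] holds x[i] + y[j].
# This is the intended z_{i,j} layout.
z_ij = x[:, None] + y[None, :]

print(z_ji)
print(z_ij)

# The two layouts are transposes of each other.
assert torch.equal(z_ij, z_ji.T)
```

Putting the new axis on `x` first (`x[:, None]`) makes the first index of z range over x, which is what $z_{i,j} = x_i + y_j$ asks for.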

Anyway, thanks for these puzzles :)

Edit: the same mix-up seems to be in Q5.

srush commented 4 months ago

We'll take a look, thanks!

srush commented 4 months ago

Fixed now.