joey00072 / Tinytorch

A really tiny autograd engine
MIT License

Small bug in tensor slicing backprop implementation #2

Closed · johnma2006 closed this issue 11 months ago

johnma2006 commented 11 months ago

Hi,

Great work on this library. I enjoyed reading through it. There is a small bug in the tensor slicing backprop implementation. Here is a minimal repro example:

x = Parameter(Tensor([0, 1]))
x[[0, 0]].sum().backward()
print(x.grad)

which gives [1, 0] whereas [2, 0] is correct. This affects the embedding layers in GPT.
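
The expected [2, 0] comes from accumulating the gradient at repeated indices rather than overwriting it. A minimal NumPy sketch of a getitem backward that behaves this way (the function name and signature are illustrative, not Tinytorch's actual internals):

import numpy as np

def getitem_backward(input_shape, index, upstream_grad):
    # Gradient of y = x[index] w.r.t. x: scatter the upstream gradient
    # back into a zero array of x's shape. np.add.at accumulates when
    # the same index appears more than once, whereas fancy-indexed
    # assignment (grad[index] += upstream_grad) applies only one update
    # per unique index and silently drops the duplicates.
    grad = np.zeros(input_shape)
    np.add.at(grad, index, upstream_grad)
    return grad

# Repro from above: x = [0, 1], y = x[[0, 0]], upstream gradient of ones
print(getitem_backward((2,), [0, 0], np.ones(2)))  # [2. 0.]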

Best, John

joey00072 commented 11 months ago

Thanks @johnma2006, I didn't check the gradient when indexing with a list/tensor (fixed now).

btw, how did you find this?

johnma2006 commented 11 months ago

I found it by inspection. I'm doing a similar project, and I remembered slicing was annoying to get right, so I paid closer attention to how you were doing it.

NumPy and PyTorch slicing are not entirely consistent with one another; for instance, try the following in both:

embedding = Parameter(rand(50257, 512))
token_batch = [[0, 1], [3, 4]]

tok_emb = embedding[token_batch]
print(tok_emb.shape)   # this is (2, 2, 512) in numpy, (2,) in torch

tok_emb.sum().backward()
print(embedding.grad.shape)
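
For a side-by-side check, wrapping the index in an explicit integer array/tensor makes the advanced-indexing intent unambiguous, and both libraries then agree on the (2, 2, 512) shape (a plain NumPy/PyTorch sketch, separate from the Parameter/rand helpers above):

import numpy as np
import torch

token_batch = [[0, 1], [3, 4]]

np_emb = np.random.rand(50257, 512)
pt_emb = torch.rand(50257, 512)

# With an explicit index array/tensor, both libraries do advanced
# indexing along dim 0 and return a (2, 2, 512) result.
print(np_emb[np.array(token_batch)].shape)      # (2, 2, 512)
print(pt_emb[torch.tensor(token_batch)].shape)  # torch.Size([2, 2, 512])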