f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/
MIT License

[ADD] Embedding: derivatives, extensions, tests #216

Closed: schaefertim closed this pull request 2 years ago

schaefertim commented 3 years ago

Embedding:

f-dangel commented 3 years ago

Question: Currently, BackPACK assumes batch_axis=0 for any Embedding layer. For cases with batch-second, this will lead to wrong results. How can this be addressed?
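To make the concern concrete, here is a minimal pure-Python sketch (illustrative names only, not BackPACK's actual code) of why the batch axis matters: per-sample embedding-weight gradients are formed by slicing the input along the batch axis, so if the data is batch-second but axis 0 is sliced as the batch, token gradients from different samples get mixed together.

```python
# Sketch, not BackPACK code: per-sample embedding-weight gradients,
# assuming batch_axis=0, i.e. token ids have shape (batch, sequence).

def per_sample_embedding_grads(token_ids, grad_output, num_embeddings, dim):
    """For each sample, scatter-add the output-gradient rows into the
    embedding-matrix rows selected by that sample's token ids.

    token_ids:   list of samples, each a list of ints       (batch, seq)
    grad_output: matching nested list of gradient vectors   (batch, seq, dim)
    Returns one (num_embeddings x dim) gradient per sample.
    """
    grads = []
    for sample_ids, sample_grad in zip(token_ids, grad_output):
        g = [[0.0] * dim for _ in range(num_embeddings)]
        for tok, go in zip(sample_ids, sample_grad):
            for d in range(dim):
                g[tok][d] += go[d]
        grads.append(g)
    return grads

# Batch-first data: 2 samples, sequence length 2, embedding dim 1.
ids = [[0, 1], [1, 1]]
gout = [[[1.0], [2.0]], [[3.0], [4.0]]]
grads = per_sample_embedding_grads(ids, gout, num_embeddings=3, dim=1)
# Sample 0 touches rows 0 and 1; sample 1 touches only row 1 (twice).
# If this same data were actually batch-second, slicing axis 0 as the
# batch would bundle tokens from different samples into one "sample".
```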

schaefertim commented 3 years ago

> Question: Currently, BackPACK assumes batch_axis=0 for any Embedding layer. For cases with batch-second, this will lead to wrong results. How can this be addressed?

I don't think there is a way to determine this from within BackPACK. The only approaches I can think of involve graph analysis, which is way too much effort.

f-dangel commented 3 years ago

> > Question: Currently, BackPACK assumes batch_axis=0 for any Embedding layer. For cases with batch-second, this will lead to wrong results. How can this be addressed?
>
> I don't think there is a way to determine this from within BackPACK. The only approaches I can think of involve graph analysis, which is way too much effort.

That is a problem for all extensions that compute per-sample quantities or support mini-batch sub-sampling. I am currently not sure how this should best be resolved.

f-dangel commented 2 years ago

As discussed offline, Embedding assumes batch_axis=0 for now. Mentioning https://github.com/fKunstner/backpack-discuss/issues/117, which will make this assumption more rigorous across the package.
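Until that lands, one user-side workaround consistent with the batch_axis=0 assumption is to transpose batch-second token ids to batch-first before they reach the Embedding layer. A toy pure-Python sketch (illustrative only, not part of BackPACK):

```python
# Illustrative only: turn batch-second token ids into batch-first so
# that code assuming batch_axis=0 slices samples correctly.

def to_batch_first(ids_seq_first):
    """Transpose a (seq, batch) nested list into (batch, seq)."""
    return [list(sample) for sample in zip(*ids_seq_first)]

seq_first = [[0, 1],   # position 0: sample-0 token, sample-1 token
             [2, 3]]   # position 1
batch_first = to_batch_first(seq_first)
# batch_first == [[0, 2], [1, 3]]: axis 0 now enumerates samples.
```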

f-dangel commented 2 years ago

All good here. Waiting for coveralls to be back online (status).

f-dangel commented 2 years ago

@schaefertim Could you also add the SqrtGGN extension for Embedding + tests?
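For reference, the quantity SqrtGGN targets is a symmetric factorization G = V Vᵀ of the generalized Gauss-Newton, with V = Jᵀ S where J is the output Jacobian with respect to the parameter and S is a square root of the loss Hessian. For an embedding weight the Jacobian is a pure selection, so each column of V is obtained by scattering a column of S into the rows picked by the token ids. A toy sketch under these assumptions (embedding dim fixed to 1; not the extension's actual code):

```python
# Toy sketch of a GGN square root for an embedding weight:
# G = J^T H J = (J^T S)(J^T S)^T = V V^T, where the embedding's
# Jacobian J just selects weight rows by token id.
# Names and shapes are illustrative only.

def embedding_ggn_sqrt(token_ids, hessian_sqrt_cols, num_embeddings):
    """Scatter each loss-Hessian square-root column into the weight
    rows selected by the token ids (embedding dim = 1 for brevity).

    token_ids:         token id per sequence position, e.g. [1, 0]
    hessian_sqrt_cols: columns of S, each of length len(token_ids)
    Returns the columns of V, each of length num_embeddings.
    """
    cols = []
    for s_col in hessian_sqrt_cols:
        v = [0.0] * num_embeddings
        for pos, tok in enumerate(token_ids):
            v[tok] += s_col[pos]  # J's row for this position selects row `tok`
        cols.append(v)
    return cols

# One sample, two positions with ids [1, 0]; loss Hessian = identity,
# so S has the unit vectors as columns.
V = embedding_ggn_sqrt([1, 0], [[1.0, 0.0], [0.0, 1.0]], num_embeddings=2)
# V columns: [0, 1] and [1, 0], so G = V V^T is the identity here.
```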