Closed schaefertim closed 2 years ago
Question: Currently, BackPACK assumes `batch_axis=0` for any `Embedding` layer. For batch-second inputs, this will lead to wrong results. How can this be addressed?
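To illustrate the failure mode, here is a hypothetical toy sketch in plain Python (not BackPACK code; all names are illustrative): per-sample embedding-weight gradients computed by slicing the assumed batch axis. With batch-second data, slicing axis 0 mixes tokens from different samples, so the per-sample gradients come out wrong even though their sum (the ordinary mini-batch gradient) is unchanged.

```python
def per_sample_grads(ids, grad_out, num_embeddings, batch_axis=0):
    """Accumulate per-sample embedding-weight gradients (toy version:
    embedding dimension 1, so gradients are plain scalars).

    ids: 2D nested list of integer token ids.
    grad_out: matching 2D nested list of output gradients.
    """
    if batch_axis == 1:
        # transpose to batch-first before slicing into samples
        ids = list(map(list, zip(*ids)))
        grad_out = list(map(list, zip(*grad_out)))
    grads = []
    for sample_ids, sample_g in zip(ids, grad_out):
        g = [0.0] * num_embeddings
        for tok, go in zip(sample_ids, sample_g):
            g[tok] += go  # scatter-add into the token's weight row
        grads.append(g)
    return grads

# Batch-SECOND data: shape [seq_len=2, batch=2].
ids_T = [[0, 1],
         [2, 0]]
g_T = [[1.0, 10.0],
       [2.0, 20.0]]

# Wrongly assuming batch_axis=0 mixes tokens across samples ...
wrong = per_sample_grads(ids_T, g_T, num_embeddings=3, batch_axis=0)
# ... while treating axis 1 as the batch axis gives the true result.
right = per_sample_grads(ids_T, g_T, num_embeddings=3, batch_axis=1)

# The summed (mini-batch) gradient agrees either way, which is why the
# bug only surfaces in per-sample quantities.
assert [sum(c) for c in zip(*wrong)] == [sum(c) for c in zip(*right)]
assert wrong != right
```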
I don't think there is a way to determine this from within BackPACK. The only ways I can think of include graph analysis which is way too much effort.
That is a problem for all extensions that compute per-sample quantities or support mini-batch sub-sampling. I am currently not sure how this should best be resolved.
As discussed offline, `Embedding` assumes `batch_axis=0` for now. Mentioning https://github.com/fKunstner/backpack-discuss/issues/117, which will make this assumption more rigorous across the package.
@schaefertim Could you also add the `SqrtGGN` extension for `Embedding`, plus tests?
Embedding:

- Derivatives: only the derivative w.r.t. the weight, since the derivative w.r.t. the input is not well-defined
- Extensions: first-order, `DiagGGN`
- Tests: derivatives, extensions
- [x] raise a similar error: `RuntimeError: only Tensors of floating point and complex dtype can require gradients`
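For context on the checklist, a hypothetical plain-Python sketch (not BackPACK's implementation; all names are illustrative) of an embedding's derivatives: the forward pass merely indexes the weight with integer ids, so the weight gradient is a scatter-add of the output gradient, while a derivative w.r.t. the integer input does not exist; asking for it raises a `RuntimeError` mirroring PyTorch's message for integer dtypes.

```python
def embedding_forward(weight, ids):
    # weight: list of embedding vectors; ids: list of integer token ids.
    # The output is just a gather of weight rows.
    return [list(weight[i]) for i in ids]

def embedding_weight_grad(grad_out, ids, num_embeddings, dim):
    # Weight derivative: scatter-add each output-gradient row back into
    # the weight row that produced it.
    grad_w = [[0.0] * dim for _ in range(num_embeddings)]
    for i, go in zip(ids, grad_out):
        for d in range(dim):
            grad_w[i][d] += go[d]
    return grad_w

def embedding_input_grad(grad_out, ids):
    # The input is integer-valued, so this derivative is not defined;
    # raise the same error message PyTorch uses for integer dtypes.
    raise RuntimeError(
        "only Tensors of floating point and complex dtype can require gradients"
    )

weight = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
ids = [2, 0, 2]
out = embedding_forward(weight, ids)  # rows 2, 0, 2 of weight
grad_w = embedding_weight_grad([[1.0, 1.0]] * 3, ids, num_embeddings=3, dim=2)
```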