f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/
MIT License

[core] Support additional dimensions in input to `Linear` #185

Closed f-dangel closed 3 years ago

f-dangel commented 3 years ago

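The derivatives of `nn.Linear` currently assume that the layer input has shape `[batch_size, in_features]` and the output has shape `[batch_size, out_features]`. In general, however, `nn.Linear` accepts inputs of shape `[batch_size, *, in_features]`, where `*` stands for an arbitrary number of additional axes, and produces outputs of shape `[batch_size, *, out_features]`; see the PyTorch documentation for `nn.Linear`.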
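For illustration, here is a minimal sketch of the shape behavior described above, using plain PyTorch. The sizes (two extra axes of length 5 and 6) and the variable names are arbitrary choices for this example and are not taken from the PR.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, in_features, out_features = 3, 4, 2
layer = nn.Linear(in_features, out_features)

# Input with two additional axes between the batch and feature axes.
x = torch.rand(batch_size, 5, 6, in_features)
out = layer(x)  # shape: [batch_size, 5, 6, out_features]

# The layer acts on the last axis only, so collapsing all leading axes
# into one, applying the layer, and restoring the shape gives the same result.
x_flat = x.reshape(-1, in_features)
out_flat = layer(x_flat).reshape(batch_size, 5, 6, out_features)
```

Because the layer treats every leading axis like a batch axis, derivative code written for 2d inputs has to account for sums over the additional axes.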

This PR adds support for additional axes to the `LinearDerivatives` in the core. Support for extensions will be provided through a separate PR.
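To make concrete what handling additional axes means for the derivatives, the following sketch computes the vector-Jacobian products of a linear layer by hand and compares them with autograd. It is not BackPACK's `LinearDerivatives` implementation; all variable names are hypothetical, and the approach (flattening the extra axes and summing over them) is only one possible way to do it.

```python
import torch
from torch import nn

torch.manual_seed(0)
N, in_features, out_features = 3, 4, 2
layer = nn.Linear(in_features, out_features)

x = torch.rand(N, 5, 6, in_features, requires_grad=True)
out = layer(x)
g = torch.rand_like(out)  # stand-in for the gradient w.r.t. the layer output

# Reference values from autograd.
grad_x, grad_w, grad_b = torch.autograd.grad(
    out, [x, layer.weight, layer.bias], grad_outputs=g
)

# Manual versions: collapse batch and additional axes into one axis.
x_flat = x.detach().reshape(-1, in_features)
g_flat = g.reshape(-1, out_features)
manual_grad_w = g_flat.T @ x_flat          # [out_features, in_features]
manual_grad_b = g_flat.sum(dim=0)          # [out_features]
manual_grad_x = g @ layer.weight.detach()  # [N, 5, 6, in_features]

# Per-sample weight gradients: keep the batch axis, sum the additional axes.
x_n = x.detach().reshape(N, -1, in_features)
g_n = g.reshape(N, -1, out_features)
per_sample_grad_w = torch.einsum("nto,nti->noi", g_n, x_n)
```

Compared with the 2d case, the only change is the extra summation over the additional axes, here the flattened axis `t` in the einsum.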

Resolves the first part of https://github.com/fKunstner/backpack-discuss/issues/99.