The derivatives of `nn.Linear` currently assume that the layer input has shape `[batch_size, in_features]` and the output has shape `[batch_size, out_features]`. In general, however, `nn.Linear` accepts inputs of shape `[batch_size, *, in_features]` with an arbitrary number of additional axes `*`; see the doc.
This PR adds support for additional axes to the `LinearDerivatives` in the `core`. Support for extensions will be provided through a separate PR.
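For illustration, here is a minimal sketch of what additional axes mean for the parameter derivatives. It is not the code from this PR: it only uses plain PyTorch autograd as a reference, and the shapes and variable names are made up for the example. The point it shows is that the weight and bias gradients must be summed over the batch axis and over every additional axis.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for this example.
batch_size, extra, in_features, out_features = 3, 4, 5, 2

layer = nn.Linear(in_features, out_features)
x = torch.rand(batch_size, extra, in_features)  # one additional axis
out = layer(x)  # shape [batch_size, extra, out_features]

# Reference gradients from autograd for an arbitrary output gradient.
grad_out = torch.rand_like(out)
grad_weight, grad_bias = torch.autograd.grad(
    out, (layer.weight, layer.bias), grad_outputs=grad_out
)

# Manual computation: merge the batch axis and all additional axes into
# one leading axis, then contract over it.
x_flat = x.flatten(end_dim=-2)         # [batch_size * extra, in_features]
g_flat = grad_out.flatten(end_dim=-2)  # [batch_size * extra, out_features]
manual_weight = g_flat.T @ x_flat      # [out_features, in_features]
manual_bias = g_flat.sum(0)            # [out_features]

print(torch.allclose(grad_weight, manual_weight, atol=1e-6))
print(torch.allclose(grad_bias, manual_bias, atol=1e-6))
```

Flattening the leading axes is just one way to write the contraction; per-sample quantities would instead keep the batch axis and sum only over the additional axes.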
Resolves the first part of https://github.com/fKunstner/backpack-discuss/issues/99.