Open nhianK opened 2 months ago

I want to use BackPACK for computing per-sample gradients and was trying to understand the challenges of using a custom model built from PyTorch nn layers, for example an architecture like this: https://github.com/codingchild2424/MonaCoBERT/blob/master/src/models/monacobert.py

Some of the basic layers used for computing attention:

```python
self.query = nn.Linear(hidden_size, self.all_head_size, bias=False)  # 512 -> 256
self.key = nn.Linear(hidden_size, self.all_head_size, bias=False)  # 512 -> 256
self.value = nn.Linear(hidden_size, self.all_head_size, bias=False)
```

The model also has a trainable nn.Parameter:

```python
self.gammas = nn.Parameter(torch.zeros(self.num_attention_heads, 1, 1))
```

and some convolutional layers.

What are some of the challenges I might face while using a model like that, and what are potential solutions? Is LayerNorm supported yet?
Not a clean solution, but aiming for the minimum amount of code to make it work.
For the nn.Linear and convolution layers, you can tell BackPACK to extend only those submodules, and that lets you extract the individual gradients for them. If you just want individual gradients for those layers, that's relatively easy: go through all the leaf-level modules in your network and, if a module is a Linear or a Conv, call extend on it.
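For concreteness, a minimal sketch of that loop (the helper name `extend_supported_leaves` is made up; it assumes you only need first-order quantities like BatchGrad, for which extending just the parameterized modules is sufficient):

```python
import torch
import torch.nn as nn
from backpack import backpack, extend
from backpack.extensions import BatchGrad

def extend_supported_leaves(model: nn.Module) -> nn.Module:
    """Extend only the supported leaf modules (Linear / Conv)."""
    for module in model.modules():
        is_leaf = len(list(module.children())) == 0
        supported = isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d))
        if is_leaf and supported:
            extend(module)
    return model

model = extend_supported_leaves(nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1)))
loss = model(torch.randn(8, 4)).sum()
with backpack(BatchGrad()):
    loss.backward()
for name, p in model.named_parameters():
    # Parameters of extended modules carry per-sample gradients.
    print(name, p.grad_batch.shape)  # [batch_size, *p.shape]
```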
The custom parameters, like self.bias in SeparableConv1D and self.gammas in MonotonicConvolutionalMultiheadAttention, are trickier. To avoid having to write custom gradient-extraction code, you could rewrite them as an nn.Linear. That would be less efficient, but it might be enough for experimentation.
For example, instead of

```python
def __init__(self):
    ...
    self.bias = nn.Parameter(...)
    ...

def forward(self, x):
    ...
    x += self.bias
    ...
```

you could do

```python
def __init__(self):
    ...
    self.dummy_linear = nn.Linear(...)
    self.dummy_linear.weight.data = torch.eye(...)  # freeze the weight as the identity
    self.dummy_linear.weight.requires_grad = False
    ...

def forward(self, x):
    ...
    x = self.dummy_linear(x)
    ...
```
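Put together, a self-contained sketch of the trick (`BiasViaLinear` is an illustrative name; it assumes a square feature dimension for the identity weight, and a BackPACK version that skips frozen parameters):

```python
import torch
import torch.nn as nn
from backpack import backpack, extend
from backpack.extensions import BatchGrad

class BiasViaLinear(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # The frozen identity weight turns the layer into a pure "+ bias";
        # the layer's built-in bias replaces the custom nn.Parameter.
        self.dummy_linear = nn.Linear(dim, dim)
        with torch.no_grad():
            self.dummy_linear.weight.copy_(torch.eye(dim))
        self.dummy_linear.weight.requires_grad = False

    def forward(self, x):
        return self.dummy_linear(x)

layer = extend(BiasViaLinear(4))
loss = layer(torch.randn(8, 4)).sum()
with backpack(BatchGrad()):
    loss.backward()
print(layer.dummy_linear.bias.grad_batch.shape)  # torch.Size([8, 4])
```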
LayerNorm could also be implemented through an nn.Linear, keeping the weight matrix diagonal with `linear.weight.data = torch.diag(torch.diag(linear.weight.data))` in place in the forward before doing the matrix multiplication.
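A sketch of what that could look like (`DiagonalLayerNorm` is an illustrative name; the normalization itself is parameter-free, so only the affine part needs the Linear trick):

```python
import torch
import torch.nn as nn

class DiagonalLayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.affine = nn.Linear(dim, dim)  # weight's diagonal acts as the LayerNorm scale
        with torch.no_grad():
            self.affine.weight.copy_(torch.eye(dim))  # scale = 1
            self.affine.bias.zero_()                  # shift = 0

    def forward(self, x):
        # Keep only the diagonal of the weight, in place, before the matmul.
        self.affine.weight.data = torch.diag(torch.diag(self.affine.weight.data))
        x = (x - x.mean(-1, keepdim=True)) / torch.sqrt(
            x.var(-1, unbiased=False, keepdim=True) + self.eps
        )
        return self.affine(x)
```

As with the bias trick, extend the module and read affine.weight.grad_batch / affine.bias.grad_batch after the backward pass; only the diagonal of the weight gradient is meaningful, since the off-diagonal entries are zeroed out again on the next forward pass.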