f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/
MIT License

BackPACK with simple attention and additional layers #326

Open nhianK opened 2 months ago

nhianK commented 2 months ago

I want to use BackPACK to compute per-sample gradients, and I am trying to understand the challenges of using a custom model built from PyTorch nn layers, for example an architecture like this one: https://github.com/codingchild2424/MonaCoBERT/blob/master/src/models/monacobert.py

Some of the basic layers used for computing attention:

    self.query = nn.Linear(hidden_size, self.all_head_size, bias=False)  # 512 -> 256
    self.key = nn.Linear(hidden_size, self.all_head_size, bias=False)  # 512 -> 256
    self.value = nn.Linear(hidden_size, self.all_head_size, bias=False)

The model also has a trainable nn.Parameter,

    self.gammas = nn.Parameter(torch.zeros(self.num_attention_heads, 1, 1))

and some convolutional layers.

What challenges might I face when using a model like this, and what are potential solutions? Is LayerNorm supported yet?

fKunstner commented 2 months ago

Not a clean solution, but one that aims for the minimum amount of code to get it working.

For the nn.Linear and convolution layers, you can tell BackPACK to extend only those submodules, which lets you extract the individual gradients of their parameters. If individual gradients for those layers are all you need, this is relatively easy: go through all the leaf-level modules in your network and, if a module is a Linear or Conv layer, call extend on it.

Custom parameters such as self.bias in SeparableConv1D and self.gammas in MonotonicConvolutionalMultiheadAttention are trickier. To avoid writing custom gradient-extraction code, you could rewrite them as an nn.Linear. That would be less efficient, but it might be enough for experimentation.

For example, instead of

def __init__(self, ...):
    ...
    self.bias = nn.Parameter(...)
    ...

def forward(self, x):
    ...
    x = x + self.bias
    ...

you could do

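def __init__(self, ...):
    ...
    self.dummy_linear = nn.Linear(...)  # must be square (in_features == out_features)
    # A parameter cannot be reassigned to a plain tensor, so copy the identity in place
    with torch.no_grad():
        self.dummy_linear.weight.copy_(torch.eye(self.dummy_linear.in_features))
    self.dummy_linear.weight.requires_grad = False  # freeze the identity; only the bias trains
    ...

def forward(self, x):
    ...
    x = self.dummy_linear(x)  # x @ I.T + bias == x + bias
    ...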

LayerNorm could also be implemented with an nn.Linear: apply the normalization without its affine part, then a Linear layer whose weight matrix is kept diagonal by running linear.weight.data = torch.diag(torch.diag(linear.weight.data)) in forward, before the matmul.