Open LB-bulb opened 1 year ago
Hi, thanks for your question.
Your approach of using the converter functionality seems correct, and BackPACK is capable of computing individual gradients for LSTM
layers. Do you have a small code snippet that reproduces the problem? This would be extremely helpful.
Felix
Hi, thanks very much for your reply. I have made a small code snippet. It may be a little long, but it reproduces the same problem.
from __future__ import print_function
import torch
from torch import nn, optim, autograd
from backpack import backpack, extend
from backpack.extensions import BatchGrad


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.ip_emb = torch.nn.Linear(1, 8)
        self.leaky_relu = torch.nn.LeakyReLU(0.1)
        self.enc_lstm = torch.nn.LSTM(8, 64, 1)
        self.decode = extend(decode(), use_converter=True)  # the part I am interested in

    def lossfun(self):
        return nn.MSELoss(reduction='sum')

    def forward(self, hist):
        _, (enc, _) = self.enc_lstm(self.leaky_relu(self.ip_emb(hist)))
        out = self.decode(enc)
        return out, enc


class decode(nn.Module):
    def __init__(self):
        super(decode, self).__init__()
        self.dec_lstm = torch.nn.LSTM(64, 8, batch_first=True)
        self.op = torch.nn.Linear(8, 1)

    def forward(self, enc):
        h_dec, _ = self.dec_lstm(enc)
        fut_pred = self.op(h_dec)
        return fut_pred


net = Net()
mse_extended = extend(net.lossfun())
x = torch.zeros(15, 5, 1)
y = torch.zeros(1, 5, 1)
out, enc = net(x)
pred = net.decode(enc)
loss = mse_extended(pred, y)
with backpack(BatchGrad()):
    loss.backward(
        inputs=list(net.decode.parameters()), retain_graph=True, create_graph=True
    )
count = 0
for name, weights in net.decode.named_parameters():
    count = count + 1
    print(count)
    print('name', name)
    print('weights', weights.shape)
    print('weights', type(weights))
    print(weights.requires_grad)
    print(weights.grad_batch.shape)
By the way, the torch version is 1.9.1. I'm not sure whether this has an effect.
Hi,
I was able to reproduce your problem with the snippet and looked at the code BackPACK executes by running with backpack(..., debug=True).
This revealed that BackPACK does not execute the BatchGrad extension on your decoder's LSTM layer (relevant output only):
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7f352eeceb50> on MSELoss()
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7f352eeceb50> on Linear(in_features=8, out_features=1, bias=True)
The reason seems to be that, with inputs=list(net.decode.parameters()) set in the backward call, PyTorch's autodiff does not fire BackPACK's backward hook on the LSTM but stops before it. (I am not familiar with the logic of backward hook execution when inputs=... is specified.) If I comment out the inputs=..., BackPACK executes on the LSTM (relevant output only, see last row):
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fac1398ed50> on MSELoss()
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fac1398ed50> on Linear(in_features=8, out_features=1, bias=True)
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fac1398ed50> on LSTM(64, 8, batch_first=True)
Hope this helps. Felix
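To see which parameters were skipped without reading the debug log, a small helper can list every trainable parameter that has no grad_batch attribute after the backward pass. This is only a sketch: missing_grad_batch is a hypothetical name, not part of BackPACK, and it relies on nothing beyond named_parameters() and the grad_batch attribute that the snippet above already uses.

```python
def missing_grad_batch(module):
    """Return the names of trainable parameters that carry no `grad_batch`.

    `module` is assumed to expose `named_parameters()` like a torch.nn.Module.
    This helper is hypothetical and not part of BackPACK.
    """
    return [
        name
        for name, param in module.named_parameters()
        # BatchGrad stores its result as the attribute `grad_batch`;
        # if the extension never ran on a layer, the attribute is absent.
        if param.requires_grad and not hasattr(param, "grad_batch")
    ]
```

Calling missing_grad_batch(net.decode) after the backward pass in the snippet should list the LSTM parameters that raise the AttributeError. Going by the debug output above, one would expect the list to be empty once inputs=... is removed from the backward call, though that has not been verified here.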
@LB-bulb @f-dangel Hello, I have the same problem when reproducing Fishr. I solved the problem by reinstalling backpack-for-pytorch==1.3.0 (the older version).
@preminstrel Hi, I am also reproducing Fishr. I use torch==1.13.1 and backpack-for-pytorch==1.3.0, and an error like
ImportError: cannot import name '_grad_input_padding' from 'torch.nn.grad'
occurred. This is because _grad_input_padding was removed in newer versions of torch. May I ask how you fixed it?
@Chelsea-abab I used torch==1.12.1. Maybe you can try this torch version?
@preminstrel Thanks! It does work!
Hi, I am also reproducing Fishr. I have the same problem as you: "AttributeError: 'Parameter' object has no attribute 'grad_batch'". I changed my versions of backpack-for-pytorch and torch to 1.3.0 and 1.12.1 respectively, but it still didn't work. Could you give me some advice? Thank you~
Hi, I fixed this problem by using backpack==1.3.0 and torch==1.12.1. This works for me. Maybe you can try?
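Since the reports in this thread depend on an exact version pair, it can help to confirm what is actually installed before debugging further. The sketch below uses only the standard library. The pinned versions are simply the combination reported to work above, not an official compatibility table, and check_versions is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version

# Combination reported to work in this thread (an assumption taken from
# the comments above, not an official compatibility statement).
REPORTED_WORKING = {"torch": "1.12.1", "backpack-for-pytorch": "1.3.0"}


def check_versions(expected=REPORTED_WORKING, get_version=version):
    """Return {package: (installed_or_None, expected)} for every mismatch."""
    mismatches = {}
    for package, wanted in expected.items():
        try:
            installed = get_version(package)
        except PackageNotFoundError:
            installed = None  # package is not installed at all
        # Ignore local build suffixes such as "+cu113" when comparing.
        if installed is None or installed.split("+")[0] != wanted:
            mismatches[package] = (installed, wanted)
    return mismatches


if __name__ == "__main__":
    for package, (installed, wanted) in check_versions().items():
        print(f"{package}: installed {installed}, thread reports {wanted}")
```

An empty result only means the installed versions match the reported pair. As one comment above notes, matching versions did not resolve the error for everyone.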
Hi, I have run into some problems using BackPACK with RNNs. The decode part is the last several layers of my model, and here is the code. I am interested in its individual gradients, so I extend this module.
Following your documentation, I use 'use_converter=True', and I use nn.MSELoss as the loss function. The part that uses BackPACK is as follows:
But when I check whether it is working with this code,
there is always an error.
By the way, I would like to ask whether there is a definition of grad_batch in BackPACK.