f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/
MIT License

AttributeError: 'Parameter' object has no attribute 'grad_batch' #270

Open LB-bulb opened 1 year ago

LB-bulb commented 1 year ago

Hi, I ran into problems when using BackPACK with RNNs. The decode module makes up the last few layers of my model; here is the code. I am interested in its individual gradients, so I extend this module.

import torch
from torch import nn

class decode(nn.Module):
    def __init__(self, args):
        super(decode, self).__init__()
        self.args = args
        # NOTE: self.conv_3x1_depth must be defined before this line (not shown in this excerpt).
        self.soc_embedding_size = (((args['grid_size'][0]-4)+1)//2)*self.conv_3x1_depth*5
        self.dyn_embedding_size = args['dyn_embedding_size']
        self.decoder_size = args['decoder_size']

        self.dec_lstm = torch.nn.LSTM(self.soc_embedding_size + self.dyn_embedding_size, self.decoder_size, batch_first=True)
        self.op = torch.nn.Linear(self.decoder_size, 2)

    def forward(self, enc):
        # Decode the encoding with the LSTM, then project each step to 2 outputs.
        h_dec, _ = self.dec_lstm(enc)
        fut_pred = self.op(h_dec)
        return fut_pred

Following your documentation, I extend it with use_converter=True, and nn.MSELoss is the loss function. The BackPACK part is as follows:

with backpack(BatchGrad()):
    # retain_graph/create_graph are arguments of backward(), not of list().
    loss.backward(
        inputs=list(decode.parameters()), retain_graph=True, create_graph=True
    )

But when I check whether it works with this code:

count = 0
for name, weights in decode.named_parameters():
    count = count + 1
    print(count)
    print('name', name)
    print('weights', weights.shape)
    print(weights.requires_grad)
    print(weights.grad_batch.shape)

it always raises an error:

Traceback (most recent call last):
  File "train.py", line 342, in <module>
    env["grads_variance"] = compute_grads_variance(encodenum, fut, net.decode,op_mask)
  File "train.py", line 85, in compute_grads_variance
    print(weights.grad_batch.shape)
AttributeError: 'Parameter' object has no attribute 'grad_batch'
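Rather than letting the check loop crash on the first parameter that lacks grad_batch, a small helper can list every parameter BackPACK did not populate. This is a sketch; missing_grad_batch is a hypothetical name, not part of BackPACK. Given the debug output later in this thread, it would presumably report the decoder's LSTM parameters here.

```python
import torch
from torch import nn


def missing_grad_batch(module: nn.Module) -> list:
    """Hypothetical helper: names of trainable parameters without `grad_batch`."""
    return [
        name
        for name, param in module.named_parameters()
        if param.requires_grad and getattr(param, "grad_batch", None) is None
    ]


# Without a BackPACK backward pass, no parameter carries `grad_batch`.
model = nn.Linear(3, 2)
print(missing_grad_batch(model))  # ['weight', 'bias']
```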

By the way, is grad_batch defined anywhere in BackPACK?

f-dangel commented 1 year ago

Hi, thanks for your question.

Your approach of using the converter functionality seems correct, and BackPACK is capable of computing individual gradients for LSTM layers. Do you have a small code snippet that reproduces the problem? That would be extremely helpful.

Felix

LB-bulb commented 1 year ago

Hi, thanks very much for your reply. I made a small code snippet. It is a bit long, but it has the same problem.

from __future__ import print_function
import torch
from torch import nn, optim, autograd
from backpack import backpack, extend
from backpack.extensions import BatchGrad

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.ip_emb = torch.nn.Linear(1,8)
        self.leaky_relu = torch.nn.LeakyReLU(0.1)
        self.enc_lstm = torch.nn.LSTM(8,64,1)
        self.decode=extend(decode(),use_converter=True) #the part I am interested in

    def lossfun(self):
        return nn.MSELoss(reduction = 'sum')

    def forward(self,hist):
        _,(enc,_) = self.enc_lstm(self.leaky_relu(self.ip_emb(hist)))
        out = self.decode(enc)
        return out, enc

class decode(nn.Module):
    def __init__(self):
        super(decode, self).__init__()
        self.dec_lstm = torch.nn.LSTM(64, 8 ,batch_first=True)
        self.op = torch.nn.Linear(8,1)

    def forward(self,enc):
        h_dec, _ = self.dec_lstm(enc)
        fut_pred = self.op(h_dec)
        return fut_pred

net = Net()
mse_extended = extend(net.lossfun())
x=torch.zeros(15,5,1)
y=torch.zeros(1,5,1)
out,enc = net(x)

pred = net.decode(enc)
loss = mse_extended(pred, y)

with backpack(BatchGrad()):
    loss.backward(
        inputs=list(net.decode.parameters()), retain_graph=True, create_graph=True
    )
count=0
for name, weights in net.decode.named_parameters():
    count=count+1
    print(count)
    print('name',name)
    print('weights',weights.shape)
    print('weights',type(weights))
    print(weights.requires_grad)
    print(weights.grad_batch.shape)

By the way, my torch version is 1.9.1. I'm not sure whether this matters.

f-dangel commented 1 year ago

Hi,

I was able to reproduce your problem with the snippet and looked at the code BackPACK executes by running with backpack(..., debug=True).

This revealed that BackPACK does not execute the BatchGrad extension on your decoder's LSTM layer (relevant output only):

[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7f352eeceb50> on MSELoss()
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7f352eeceb50> on Linear(in_features=8, out_features=1, bias=True)

The reason seems to be that, when inputs=list(net.decode.parameters()) is set in the backward call, PyTorch's autodiff does not fire BackPACK's backward hook on the LSTM but stops before reaching it. (I am not familiar with the logic of backward hook execution when inputs=... is specified.) If I comment out inputs=..., BackPACK executes on the LSTM (relevant output only, see last row):

[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fac1398ed50> on MSELoss()
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fac1398ed50> on Linear(in_features=8, out_features=1, bias=True)
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fac1398ed50> on LSTM(64, 8, batch_first=True)
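The effect of inputs= can be seen in plain PyTorch, without BackPACK. When inputs= is given, autograd accumulates gradients only into the listed tensors and skips work not needed for them. Here is a minimal sketch with two hypothetical Linear layers:

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
encoder = nn.Linear(3, 4)
head = nn.Linear(4, 1)

x, y = torch.randn(5, 3), torch.randn(5, 1)
loss = F.mse_loss(head(encoder(x)), y, reduction="sum")

# Only the head's parameters are requested: autograd fills their .grad
# and never accumulates anything into the encoder's parameters.
loss.backward(inputs=list(head.parameters()))

print(head.weight.grad is not None)  # True
print(encoder.weight.grad is None)   # True
```

In the issue, the LSTM's parameters are in inputs, but the gradient with respect to the LSTM's own input is not needed. That is presumably why the module-level backward hook BackPACK relies on does not fire there. This is an interpretation and is not confirmed in the thread.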

Hope this helps. Felix

preminstrel commented 1 year ago

@LB-bulb @f-dangel Hello, I had the same problem when reproducing Fishr. I solved it by reinstalling the older backpack-for-pytorch==1.3.0.

Chelsea-abab commented 1 year ago

@preminstrel Hi, I am also reproducing Fishr. With torch==1.13.1 and backpack-for-pytorch==1.3.0, I get ImportError: cannot import name '_grad_input_padding' from 'torch.nn.grad'. This is because _grad_input_padding was removed in newer versions of torch. May I ask how you fixed it?
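Before pinning versions, a small diagnostic sketch can report the installed versions and whether torch still provides the private helper _grad_input_padding. According to the error above, backpack-for-pytorch 1.3.0 tries to import this helper. The function name backpack_version is hypothetical.

```python
import importlib.metadata

import torch
import torch.nn.grad


def backpack_version():
    """Hypothetical helper: installed backpack-for-pytorch version, or None."""
    try:
        return importlib.metadata.version("backpack-for-pytorch")
    except importlib.metadata.PackageNotFoundError:
        return None


# The private helper named in the ImportError above.
has_padding_helper = hasattr(torch.nn.grad, "_grad_input_padding")

print("torch:", torch.__version__)
print("backpack-for-pytorch:", backpack_version())
print("torch.nn.grad._grad_input_padding available:", has_padding_helper)
```

In this thread, torch 1.12.1 worked with backpack-for-pytorch 1.3.0. For other torch versions, check the helper's presence with the snippet rather than assuming it.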

preminstrel commented 1 year ago

@Chelsea-abab I used torch == 1.12.1. Maybe you can try this torch version?

Chelsea-abab commented 1 year ago

@preminstrel Thanks! It does work!

Cccjl219 commented 3 months ago

Hi, I am also reproducing Fishr and have the same problem as you: "AttributeError: 'Parameter' object has no attribute 'grad_batch'". I changed my versions of backpack-for-pytorch and torch to 1.3.0 and 1.12.1 respectively, but it still doesn't work. Could you give me some advice? Thank you~

Chelsea-abab commented 3 months ago

Hi, I fixed this problem by using backpack-for-pytorch==1.3.0 and torch==1.12.1. This works for me. Maybe you can try it?