f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/

Facing error while Using DiagHessian for torchvision.models.resnet18 #318

Closed AliGhadirii closed 10 months ago

AliGhadirii commented 10 months ago

Hi,

I want to apply the DiagHessian for my torchvision.models.resnet18 model:

[screenshot: code applying BackPACK's DiagHessian to the resnet18 model]
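In essence, I am doing the following (sketched here with dummy data, since the exact batch and input shape don't matter):

```python
import torch
from torchvision.models import resnet18
from backpack import backpack, extend
from backpack.extensions import DiagHessian

model = extend(resnet18())
loss_func = extend(torch.nn.CrossEntropyLoss())

X = torch.rand(8, 3, 224, 224)    # dummy batch
y = torch.randint(0, 1000, (8,))  # dummy labels

loss = loss_func(model(X), y)
with backpack(DiagHessian()):
    loss.backward()  # this is where the error below is raised
```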

I get this error:

RuntimeError: Output 0 of BackwardHookFunctionBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.

As the error suggests, I should clone the output of the layers or blocks in my model that contain any in-place operation. In resnet18, the in-place operations are the relu functions and the addition in the BasicBlock modules (line 102):

[screenshot: torchvision's BasicBlock source, showing the in-place operations]
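For reference, the relevant part of torchvision's BasicBlock.forward looks roughly like this (paraphrased; exact line numbers differ between torchvision versions):

```python
# paraphrased from torchvision.models.resnet.BasicBlock.forward
def forward(self, x):
    identity = x

    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)  # self.relu = nn.ReLU(inplace=True)

    out = self.conv2(out)
    out = self.bn2(out)

    if self.downsample is not None:
        identity = self.downsample(x)

    out += identity       # in-place addition (the line 102 mentioned above)
    out = self.relu(out)  # in-place again

    return out
```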

So I changed my forward function to feed the data layer by layer and clone the output at each step to resolve the issue. Here is my model with its forward function:

[screenshot: the modified model and its forward function]
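Schematically, what I did looks like the following (the class name ClonedResNet18 is just for illustration; the layer names follow torchvision's resnet18 attributes):

```python
import torch
from torch import nn
from torchvision.models import resnet18


class ClonedResNet18(nn.Module):
    """Wrapper that feeds data layer by layer and clones after each step."""

    def __init__(self):
        super().__init__()
        self.model = resnet18()

    def forward(self, x):
        x = self.model.conv1(x).clone()
        x = self.model.bn1(x).clone()
        x = self.model.relu(x).clone()
        x = self.model.maxpool(x).clone()
        x = self.model.layer1(x).clone()
        x = self.model.layer2(x).clone()
        x = self.model.layer3(x).clone()
        x = self.model.layer4(x).clone()
        x = self.model.avgpool(x).clone()
        x = torch.flatten(x, 1)
        return self.model.fc(x)
```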

The problem still persists after this change. Can anyone help me with this? Am I cloning the outputs correctly?

f-dangel commented 10 months ago

Hi Ali,

thanks for reaching out. The error you are seeing comes from an in-place activation: in-place operations in PyTorch are incompatible with the backward hooks that BackPACK relies on. We have a converter that automatically replaces them with their inplace=False equivalents, which should fix this error. To use it, call extend(model, use_converter=True) when extending the model.
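For example (a minimal sketch):

```python
from torch.nn import CrossEntropyLoss
from torchvision.models import resnet18
from backpack import extend

# the converter rewrites unsupported constructs, e.g. inplace=True activations,
# into BackPACK-compatible equivalents
model = extend(resnet18(), use_converter=True)
loss_func = extend(CrossEntropyLoss())
```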

FYI: DiagHessian is currently not supported for ResNets in BackPACK because computing the Hessian diagonal for such graphs is more challenging. So you will see a new not-supported error after applying the above fix. However, if your network has only ReLU activations (I believe this is true for ResNet18), you can use DiagGGNExact which is supported for such architectures and coincides with the Hessian diagonal.
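A minimal sketch of that workflow with dummy data (the GGN diagonal ends up in each parameter's diag_ggn_exact attribute):

```python
import torch
from torchvision.models import resnet18
from backpack import backpack, extend
from backpack.extensions import DiagGGNExact

model = extend(resnet18(), use_converter=True)
loss_func = extend(torch.nn.CrossEntropyLoss())

X = torch.rand(8, 3, 224, 224)    # dummy batch
y = torch.randint(0, 1000, (8,))  # dummy labels

loss = loss_func(model(X), y)
with backpack(DiagGGNExact()):
    loss.backward()

for name, param in model.named_parameters():
    print(name, param.diag_ggn_exact.shape)  # same shape as param
```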

Cheers, Felix