f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/
MIT License

Hutchinson trace example does not work on ResNet18 #249

Closed jlmckins closed 2 years ago

jlmckins commented 2 years ago

I was able to run the second-order examples after using `extend(model, use_converter=True)` for ResNet18. However, when I try to run the Hutchinson trace example, I get the following error:

NotImplementedError: Extension saving to diag_h does not have an extension for Module <class 'backpack.custom_module.branching.SumModule'>

Is it possible to extend this module in order to be able to compute the Hutchinson trace layerwise for ResNet models?

Thank you, Jeff

Here is part of the test code:

```python
# Imports assumed by this snippet:
from backpack import backpack, extend
from backpack.extensions import BatchDiagHessian, DiagHessian, HMP


def calc_hutchison_trace(model, criterion):
    model.eval()
    model = extend(model, use_converter=True)
    criterion.to(device)
    loss_function = extend(criterion)

    # In the following, we load a batch, compute the loss and trigger the
    # backward pass ``with backpack(...)`` such that we have access to the
    # extensions that we are going to use (``DiagHessian`` and ``HMP``).
    for i, data in enumerate(trainloader, 0):
        x, y = data
        x = x.to(device)
        y = y.to(device)
        break  # Get 1 batch

    def forward_backward_with_backpack():
        """Provide working access to BackPACK's `DiagHessian` and `HMP`."""
        loss = loss_function(model(x), y)

        with backpack(DiagHessian(), HMP()):
            # keep graph for autodiff HVPs
            loss.backward(retain_graph=True)

        return loss

    # Explicit test to see if diag info is created.
    loss = loss_function(model(x), y)
    with backpack(DiagHessian(), BatchDiagHessian()):
        loss.backward()
    for name, param in model.named_parameters():
        print(name)
        print(".grad.shape:             ", param.grad.shape)
        print(".diag_h.shape:           ", param.diag_h.shape)
        print(".diag_h_batch.shape:     ", param.diag_h_batch.shape)
```

jlmckins commented 2 years ago

Here is the complete error message:

```
Using pytorch model zoo pretrained model as teacher: resnet18
Traceback (most recent call last):
  File "train_resnet_imagenet_stepsizex4init_hiway4.py", line 1317, in <module>
    calc_hutchison_trace(teacher, simplecriterion)
  File "train_resnet_imagenet_stepsizex4init_hiway4.py", line 1261, in calc_hutchison_trace
    loss.backward()
  File "/u/jlmckins/.local/lib/python3.7/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/u/jlmckins/.local/lib/python3.7/site-packages/torch/autograd/__init__.py", line 156, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
  File "/u/jlmckins/.local/lib/python3.7/site-packages/torch/utils/hooks.py", line 110, in hook
    res = user_hook(self.module, grad_input, self.grad_outputs)
  File "/u/jlmckins/.local/lib/python3.7/site-packages/backpack/__init__.py", line 209, in hook_run_extensions
    backpack_extension(module, g_inp, g_out)
  File "/u/jlmckins/.local/lib/python3.7/site-packages/backpack/extensions/backprop_extension.py", line 125, in __call__
    module_extension = self.__get_module_extension(module)
  File "/u/jlmckins/.local/lib/python3.7/site-packages/backpack/extensions/backprop_extension.py", line 100, in __get_module_extension
    f"Extension saving to {self.savefield} "
NotImplementedError: Extension saving to diag_h does not have an extension for Module <class 'backpack.custom_module.branching.SumModule'>
```

f-dangel commented 2 years ago

Hi @jlmckins,

BackPACK's DiagHessian extension does not support branched computation graphs like ResNets yet. That is because computing the Hessian diagonal in the presence of arbitrary branching becomes a headache.

For your specific case, there's a way out: Since you're using a ResNet with identity skip connections and ReLUs, the Hessian trace is identical to the GGN trace. You can thus use BackPACK's DiagGGNExact that supports ResNets, and sum the diagonal to obtain the trace.

Best, Felix

jlmckins commented 2 years ago

Thank you for the fast response! I will give it a try.