f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/
MIT License

Second Order Extensions for Custom Loss Modules #325

Closed JRopes closed 3 months ago

JRopes commented 4 months ago

Hello! I am currently trying to extend a soft dice loss so that I can use the DiagGGN approximation with it. Could anyone summarize the steps I would have to go through to make this possible? I am a little bit lost in this library. Any help is appreciated!

f-dangel commented 4 months ago

Hi Jakob,

we currently have no tutorial that explicitly explains how to support a new loss for DiagGGN. The best way to get started is to understand how BackPACK implements second-order extensions by looking at this tutorial, which at the time of writing is still on the development branch and has not been merged into master yet. I believe it has sufficient detail to understand what object the module extension you write for your loss function needs to create.

The implementation follows the same pattern: you write a module extension that implements a backpropagate function and then register it inside BackPACK, so it knows to call it whenever it encounters a soft dice loss layer.
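A rough sketch of what this could look like (not necessarily BackPACK's exact API: the shape convention and the `set_module_extension` registration are assumptions you should double-check against the tutorial, and `soft_dice_hessian_sqrt` is a placeholder for the math you would have to derive):

```python
import torch
from torch import Tensor
from backpack.extensions import DiagGGNExact
from backpack.extensions.module_extension import ModuleExtension


class SoftDiceLoss(torch.nn.Module):
    """Stand-in for your soft dice loss module."""

    def forward(self, prediction: Tensor, target: Tensor) -> Tensor:
        intersection = (prediction * target).sum()
        return 1.0 - 2.0 * intersection / (prediction.sum() + target.sum())


def soft_dice_hessian_sqrt(module, g_inp, g_out) -> Tensor:
    """Placeholder for the actual math: return a symmetric factorization S of the
    Hessian of the loss w.r.t. its input, shaped [V, N, *input_shape], so that
    summing S[v] S[v]^T over v recovers that Hessian (convention assumed here)."""
    raise NotImplementedError("derive this for the soft dice loss")


class DiagGGNSoftDiceLoss(ModuleExtension):
    def backpropagate(self, extension, module, g_inp, g_out, bp_quantity):
        # For a loss module, bp_quantity is None: the loss starts the backward
        # pass of the GGN's symmetric factorization.
        return soft_dice_hessian_sqrt(module, g_inp, g_out)


extension = DiagGGNExact()
# Register the module extension so BackPACK calls it for SoftDiceLoss layers.
extension.set_module_extension(SoftDiceLoss, DiagGGNSoftDiceLoss())
```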

Best, Felix

JRopes commented 3 months ago

Thank you @f-dangel for the hint! I am still a bit confused. I see that the other losses have an implementation in the core directory, such as the CrossEntropyLossDerivatives class. Would I not also have to implement an equivalent? Thank you very much for your help!

f-dangel commented 3 months ago

Hi Jakob, you are right that BackPACK's supported loss functions call out to functionality in the core directory from their backpropagate function. The core directory abstracts the autodiff functionality required by all the extensions and helps avoid duplicating code across extensions that rely on the same AD functionality.

You don't have to touch core to get your own custom loss function running with a single extension.

JRopes commented 3 months ago

Thank you @f-dangel ! I think I have this mostly working now, but one small thing still confuses me. I keep getting the following AssertionError:

AssertionError: BackPACK extension expects a backpropagation quantity but it is None. Module: SoftDiceLossV2(), Extension: <backpack.extensions.secondorder.diag_ggn.DiagGGNExact object at 0x14f9d3b00410>.

Since my loss is the first step of the chain rule in the backward pass, my understanding is that it does not require a bp_quantity, as I am simply computing the GGN of the loss w.r.t. the model output. Is this assumption correct? If yes, how do I let BackPACK know that this extension does not require bp_quantities?
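To spell out what I mean, the quantity I am after is the generalized Gauss-Newton

```math
G(\theta) = J_\theta f(\theta)^\top \, \nabla_f^2 \ell\big(f(\theta), y\big) \, J_\theta f(\theta),
```

so the loss module only contributes the Hessian of the loss w.r.t. the network output (or a symmetric factorization of it) and starts the backpropagation without receiving anything.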

f-dangel commented 3 months ago

Looks like you almost have it.

The error is caused by this line, because BackPACK does not recognize your SoftDiceLossV2 as a loss function. BackPACK currently uses this function to detect whether a module is a loss (basically, it checks whether your module is a subclass of torch.nn.modules.loss._Loss). You could either make SoftDiceLossV2 a child class of _Loss, or patch BackPACK's is_loss function.
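For the first option, a minimal sketch could look like this (the dice computation itself is only illustrative, keep your own):

```python
import torch
from torch import Tensor
from torch.nn.modules.loss import _Loss


class SoftDiceLossV2(_Loss):
    """Same loss as before, but inheriting from _Loss so that BackPACK's
    is_loss check recognizes it as a loss function."""

    def __init__(self, eps: float = 1e-6, reduction: str = "mean"):
        super().__init__(reduction=reduction)
        self.eps = eps

    def forward(self, prediction: Tensor, target: Tensor) -> Tensor:
        # Illustrative soft dice; replace with your actual computation.
        intersection = (prediction * target).sum()
        denominator = prediction.sum() + target.sum()
        return 1.0 - (2.0 * intersection + self.eps) / (denominator + self.eps)
```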

JRopes commented 3 months ago

Okay, this now works @f-dangel, and I cannot tell you enough how much I appreciate your help! However, I now have a weird error that did not occur when I did this with your already supported CrossEntropyLoss. Here is some context:

I am working on an Img2Img problem where the input has dimension [1,1,64,64]. For the final two layers I use a ConvTranspose2d (output of dimension [1,1,64,64]) followed by a Flatten layer, making the final output [1,1,4096]. The loss is then calculated with my SoftDiceLoss, for which I have implemented the extension.
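Roughly, the tail of the network looks like this (hyperparameters made up here, just to pin down the shapes):

```python
import torch

# Last two layers: ConvTranspose2d keeps the spatial size, Flatten collapses it.
tail = torch.nn.Sequential(
    torch.nn.ConvTranspose2d(1, 1, kernel_size=3, padding=1),  # [1, 1, 64, 64]
    torch.nn.Flatten(start_dim=2),                             # [1, 1, 4096]
)
x = torch.randn(1, 1, 64, 64)
print(tail(x).shape)  # torch.Size([1, 1, 4096])
```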

During the computation of DiagGGNExact, my loss extension passes a bp_quantity of shape [1,1,4096] to the Flatten backprop extension, which then passes a bp_quantity of shape [1,1,64,64] to the ConvTranspose2d backprop extension. Now I get this error:

Traceback (most recent call last):
  File "/homes/jropers/Code/confidentmedsam/LaplaceUNet/Laplace_UNet_ScaleUp64_V2.py", line 376, in <module>
    diag_Sigma, unfold_transpose, model = predict(experiment, train_loader, test_loader, model)
                                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/homes/jropers/Code/confidentmedsam/LaplaceUNet/Laplace_UNet_ScaleUp64_V2.py", line 153, in predict
    loss.backward()
  File "/homes/jropers/anaconda3/envs/ConfidentSAM_CUDA_11-7/lib/python3.11/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/homes/jropers/anaconda3/envs/ConfidentSAM_CUDA_11-7/lib/python3.11/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/homes/jropers/anaconda3/envs/ConfidentSAM_CUDA_11-7/lib/python3.11/site-packages/torch/utils/hooks.py", line 138, in hook
    out = hook(self.module, res, self.grad_outputs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/homes/jropers/anaconda3/envs/ConfidentSAM_CUDA_11-7/lib/python3.11/site-packages/backpack/__init__.py", line 209, in hook_run_extensions
    backpack_extension(module, g_inp, g_out)
  File "/homes/jropers/anaconda3/envs/ConfidentSAM_CUDA_11-7/lib/python3.11/site-packages/backpack/extensions/backprop_extension.py", line 128, in __call__
    module_extension(self, module, g_inp, g_out)
  File "/homes/jropers/anaconda3/envs/ConfidentSAM_CUDA_11-7/lib/python3.11/site-packages/backpack/extensions/module_extension.py", line 117, in __call__
    extValue = extFunc(extension, module, g_inp, g_out, bp_quantity)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/homes/jropers/anaconda3/envs/ConfidentSAM_CUDA_11-7/lib/python3.11/site-packages/backpack/extensions/secondorder/diag_ggn/convtransposend.py", line 12, in weight
    return convUtils.extract_weight_diagonal(module, X, backproped, sum_batch=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/homes/jropers/anaconda3/envs/ConfidentSAM_CUDA_11-7/lib/python3.11/site-packages/backpack/utils/conv_transpose.py", line 96, in extract_weight_diagonal
    JS = einsum("ngckx,vngox->vngcok", unfolded_input, S)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/homes/jropers/anaconda3/envs/ConfidentSAM_CUDA_11-7/lib/python3.11/site-packages/torch/functional.py", line 380, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: einsum(): subscript x has size 64 for operand 1 which does not broadcast with previously seen size 4096

The bp_quantities seem to pass all the shape checks, and this model architecture worked with the CrossEntropyLoss. When I analyze how the bp_quantities are passed with your CrossEntropyLoss extension, the ConvTranspose2d layer handles the 64-by-64 shape coming from the Flatten backprop just fine. Do you have any idea why I suddenly run into this error?

Again, thank you so much for your help!

JRopes commented 3 months ago

I fixed this! I realized the output dimension of the bp_quantity in my loss extension was incorrect! Thank you for all the help @f-dangel !
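In case anyone else runs into this, a small check like the one below would have caught it for me (the convention of a leading V dimension followed by the loss input's shape is my reading of the einsum in the traceback, so treat it as an assumption):

```python
from torch import Tensor


def check_loss_bp_quantity(sqrt_hessian: Tensor, loss_input: Tensor) -> None:
    """Assumed convention: the quantity returned by the loss extension's
    backpropagate has shape [V, N, *loss_input.shape[1:]]."""
    expected = tuple(loss_input.shape)
    actual = tuple(sqrt_hessian.shape[1:])
    assert actual == expected, f"expected trailing shape {expected}, got {actual}"
```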

f-dangel commented 3 months ago

Great! Let me know if there's a way to improve the docs to further simplify this.

Best, Felix