Closed · JRopes closed this 3 months ago
Hi Jakob,
we currently have no tutorial that explicitly explains how to support a new loss for DiagGGN. The best way to get started is to understand how BackPACK implements second-order extensions by looking at this tutorial, which at the time of writing is still on the development branch and has not yet been merged into master. I believe it has sufficient detail to understand which object the module extension you write for your loss function needs to create.
The implementation would be similar in that you have to write a module extension that implements a backpropagate function, and then register it inside BackPACK so it knows to call it whenever it encounters a soft dice loss layer.
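To make the idea concrete, here is a toy sketch in plain PyTorch (not BackPACK's actual interface) of what backpropagating the DiagGGN quantity through one linear layer looks like: the quantity is a symmetric factorization S of the loss Hessian with respect to the layer output, and each layer applies its transposed Jacobian to it. All shapes and names here are illustrative.

```python
import torch

# Toy setup: a linear "network" f(x) = W x and the quadratic loss 0.5*||f||^2.
torch.manual_seed(0)
N, D_in, D_out = 3, 4, 2           # batch size, input dim, output dim
W = torch.randn(D_out, D_in)
x = torch.randn(N, D_in)
out = x @ W.T                       # network output, shape [N, D_out]

# The backpropagated quantity is a per-sample factorization S with H = S @ S.T,
# where H is the loss Hessian w.r.t. the output. For 0.5*||out||^2, H = I,
# so S is the identity for every sample.
S = torch.eye(D_out).unsqueeze(0).expand(N, -1, -1)    # [N, D_out, D_out]

# Backpropagating S through the linear layer applies J^T, with J = W here.
S_in = torch.einsum("oi,nov->niv", W, S)               # [N, D_in, D_out]

# The GGN w.r.t. the layer input is J^T H J = (J^T S)(J^T S)^T.
ggn_in = torch.einsum("niv,njv->nij", S_in, S_in)
# Brute-force check: J^T I J = W^T W for every sample.
assert torch.allclose(ggn_in, (W.T @ W).expand(N, -1, -1), atol=1e-6)
```

DiagGGN then only stores the diagonal of such blocks; the full chain works because each layer needs nothing but S and its own Jacobian.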
Best, Felix
Thank you @f-dangel for the hint! I am still a bit confused. I see that the other losses have an implementation in the core directory, e.g. a class such as CrossEntropyLossDerivatives. Would I not also have to implement an equivalent of this? Thank you very much for your help!
Hi Jakob,
you are right that BackPACK's supported loss functions call out to functionality in the core directory in their backpropagate function. The core directory abstracts the autodiff functionality required by all the extensions and helps avoid duplicating the implementation of extensions that rely on the same AD functionality.
You don't have to touch core to get your own custom loss function running with a single extension.
Thank you @f-dangel ! I think I have this correct for the most part now. I am confused about one more small thing. I keep getting the following AssertionError:
AssertionError: BackPACK extension expects a backpropagation quantity but it is None. Module: SoftDiceLossV2(), Extension: <backpack.extensions.secondorder.diag_ggn.DiagGGNExact object at 0x14f9d3b00410>.
Since my loss is the first step of the chain rule in the backward pass, my understanding is that it does not require an incoming bp_quantity, as I am simply computing the GGN of the loss w.r.t. the model output. Is this assumption correct? If so, how do I let BackPACK know that this extension does not require a bp_quantity?
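Concretely, what I believe the loss extension should emit is a symmetric factorization S of the loss Hessian w.r.t. the model output, with H = S @ S.T. A hedged sketch using autograd's exact Hessian plus an eigendecomposition (fine for checking a hand-derived S against, far too slow for production; the function name is mine, not BackPACK API):

```python
import torch

def hessian_sqrt(loss_fn, output):
    """Return S with S @ S.T == Hessian of loss_fn at `output` (1D tensor)."""
    H = torch.autograd.functional.hessian(loss_fn, output)
    evals, evecs = torch.linalg.eigh(H)        # H is symmetric
    evals = evals.clamp(min=0.0)               # the GGN factor must be PSD
    return evecs * evals.sqrt()                # S = Q diag(sqrt(lambda))

output = torch.tensor([0.3, -0.7, 1.2])
loss_fn = lambda z: 0.5 * (z ** 2).sum()       # toy quadratic loss, H = I
S = hessian_sqrt(loss_fn, output)
assert torch.allclose(S @ S.T, torch.eye(3), atol=1e-6)
```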
Looks like you almost have it.
The error is caused by this line: BackPACK does not recognize your SoftDiceLossV2 as a loss function. BackPACK currently uses this function to detect whether a module is a loss (basically checking whether the module is a subclass of torch.nn.modules.loss._Loss). You could either make SoftDiceLossV2 a child class of _Loss, or patch BackPACK's is_loss function.
Okay, this now works @f-dangel, and I cannot tell you enough how much I appreciate your help! However, I now hit a weird error that did not occur when I did this with your already supported CrossEntropyLoss. Some context:
I am working on an Img2Img problem where the input has dimension [1, 1, 64, 64]. The final two layers are a ConvTranspose2d (output of dim [1, 1, 64, 64]) followed by a Flatten layer, making the final output [1, 1, 4096]. The loss is then calculated with my SoftDiceLoss, for which I have implemented the extension.
During the computation of DiagGGNExact, my loss extension passes a bp_quantity of shape [1, 1, 4096] to the Flatten backprop extension, which then passes a bp_quantity of shape [1, 1, 64, 64] to the ConvTranspose2d backprop extension. Now I get this error:
Traceback (most recent call last):
File "/homes/jropers/Code/confidentmedsam/LaplaceUNet/Laplace_UNet_ScaleUp64_V2.py", line 376, in <module>
diag_Sigma, unfold_transpose, model = predict(experiment, train_loader, test_loader, model)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/homes/jropers/Code/confidentmedsam/LaplaceUNet/Laplace_UNet_ScaleUp64_V2.py", line 153, in predict
loss.backward()
File "/homes/jropers/anaconda3/envs/ConfidentSAM_CUDA_11-7/lib/python3.11/site-packages/torch/_tensor.py", line 522, in backward
torch.autograd.backward(
File "/homes/jropers/anaconda3/envs/ConfidentSAM_CUDA_11-7/lib/python3.11/site-packages/torch/autograd/__init__.py", line 266, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/homes/jropers/anaconda3/envs/ConfidentSAM_CUDA_11-7/lib/python3.11/site-packages/torch/utils/hooks.py", line 138, in hook
out = hook(self.module, res, self.grad_outputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/homes/jropers/anaconda3/envs/ConfidentSAM_CUDA_11-7/lib/python3.11/site-packages/backpack/__init__.py", line 209, in hook_run_extensions
backpack_extension(module, g_inp, g_out)
File "/homes/jropers/anaconda3/envs/ConfidentSAM_CUDA_11-7/lib/python3.11/site-packages/backpack/extensions/backprop_extension.py", line 128, in __call__
module_extension(self, module, g_inp, g_out)
File "/homes/jropers/anaconda3/envs/ConfidentSAM_CUDA_11-7/lib/python3.11/site-packages/backpack/extensions/module_extension.py", line 117, in __call__
extValue = extFunc(extension, module, g_inp, g_out, bp_quantity)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/homes/jropers/anaconda3/envs/ConfidentSAM_CUDA_11-7/lib/python3.11/site-packages/backpack/extensions/secondorder/diag_ggn/convtransposend.py", line 12, in weight
return convUtils.extract_weight_diagonal(module, X, backproped, sum_batch=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/homes/jropers/anaconda3/envs/ConfidentSAM_CUDA_11-7/lib/python3.11/site-packages/backpack/utils/conv_transpose.py", line 96, in extract_weight_diagonal
JS = einsum("ngckx,vngox->vngcok", unfolded_input, S)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/homes/jropers/anaconda3/envs/ConfidentSAM_CUDA_11-7/lib/python3.11/site-packages/torch/functional.py", line 380, in einsum
return _VF.einsum(equation, operands) # type: ignore[attr-defined]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: einsum(): subscript x has size 64 for operand 1 which does not broadcast with previously seen size 4096
The bp_quantities seem to pass all the shape checks. Also, this model architecture worked with CrossEntropyLoss: when I trace how the bp_quantities are passed with your CE loss extension, the ConvTranspose2d layer handles the reshaped 64-by-64 quantity from the Flatten backprop just fine. Do you have any idea why I suddenly run into this error?
Again, thank you so much for your help!
I fixed this! I realized the output dimension of the bp_quantity in my loss was incorrect. Thank you for all the help @f-dangel!
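For anyone hitting the same einsum size clash: a cheap sanity check in the loss extension would have caught this. Reading off the shapes in the traceback (subscript v leading), the emitted quantity should trail the module output's shape after a leading V dimension; the convention and helper name here are my reading of the thread, not official BackPACK API.

```python
import torch

def check_bp_quantity(bp_quantity: torch.Tensor, output: torch.Tensor) -> None:
    """Assert bp_quantity has shape [V, *output.shape] (V = sqrt-Hessian columns)."""
    actual = bp_quantity.shape[1:]            # drop the leading V dimension
    assert actual == output.shape, (
        f"bp_quantity trails {tuple(actual)}, module output is {tuple(output.shape)}"
    )

output = torch.randn(1, 1, 4096)              # flattened network output
good = torch.randn(5, 1, 1, 4096)             # V = 5 square-root columns
check_bp_quantity(good, output)               # passes silently
# A quantity shaped like the *pre-Flatten* output, e.g. [5, 1, 1, 64, 64],
# would fail this check instead of blowing up later inside an einsum.
```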
Great! Let me know if there's a way to improve the docs to further simplify this.
Best, Felix
Hello! I am currently trying to extend a soft dice loss so that I can use the DiagGGN approximation with it. Could anyone summarize the steps I would have to go through to make this possible? I am a little lost in this library. Any help is appreciated!