f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/
MIT License

Missing Support for BatchNorm and AdaptiveAvgPool in HBP methods (KFAC, KFRA, KFLR) and GGNMP #322

Open satnavpt opened 4 months ago

satnavpt commented 4 months ago

I'm not sure whether this is similar to a previous issue where a link was simply missing, or whether there is a more fundamental reason why these modules, and others that the documentation claims are supported, are not actually supported.

f-dangel commented 4 months ago

Hi, thanks for pointing out this documentation inconsistency. You are right that the docs overstate which layers are supported, although most of the second-order extensions do support the mentioned layers.

If you would like support for a specific layer and quantity that is currently missing, please let us know which ones.

satnavpt commented 4 months ago

Hi, thanks for the quick reply. I am looking to experiment with KFLR and a Conjugate Gradient optimiser (using GGNMP) on a resnet18 model. I am fine with running in eval mode, but neither of these extensions currently has definitions for BatchNorm and AdaptiveAvgPool, so an error is thrown. Support for these would be greatly appreciated.
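A conjugate gradient solver only accesses the curvature matrix through matrix-vector products, which is why a GGN matrix product such as GGNMP is enough to drive it. A minimal pure-Python sketch, assuming a symmetric positive definite operator supplied as a callable; `conjugate_gradient` and its parameters are hypothetical and not BackPACK's example optimiser:

```python
from typing import Callable, List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def conjugate_gradient(
    matvec: Callable[[Vector], Vector],
    b: Vector,
    max_iter: int = 100,
    tol_sq: float = 1e-20,
) -> Vector:
    """Approximately solve A x = b, where A is only available through matvec(v) = A v.

    Hypothetical helper; assumes A is symmetric positive definite.
    tol_sq is a threshold on the squared residual norm.
    """
    x = [0.0] * len(b)
    r = list(b)  # residual b - A x for the initial guess x = 0
    p = list(r)  # first search direction
    rs_old = dot(r, r)
    for _ in range(max_iter):
        if rs_old <= tol_sq:
            break
        Ap = matvec(p)  # the only place where the matrix is touched
        alpha = rs_old / dot(p, Ap)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = dot(r, r)
        # New direction, conjugate to the previous ones w.r.t. A
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x


# Usage: a 2x2 SPD matrix stands in for the (damped) GGN
A = [[4.0, 1.0], [1.0, 3.0]]
matvec = lambda v: [dot(row, v) for row in A]
x = conjugate_gradient(matvec, [1.0, 2.0])
```

In a Hessian-free style optimiser, `matvec` would typically wrap the damped GGN-vector product and `b` would be the negative gradient.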

satnavpt commented 4 months ago

I tried making some changes manually, but the issue seems to go beyond missing links to the module extensions...

f-dangel commented 4 months ago

Thanks for boiling things down. I think what you are requesting requires a lot of new additions to BackPACK. They are not impossible to implement, but you would mostly be on your own.

satnavpt commented 4 months ago

I've made the following changes to get GGNMP working with the SumModule, but it seems my gradients are vanishing:

Accumulation for GGNMP:

```python
from typing import Callable

def accumulate_backpropagated_quantities(self, existing: Callable, other: Callable) -> Callable:
    return lambda mat: existing(mat) + other(mat)  # sum the matrix products of both branches
```
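The accumulation relies on matrix products distributing over addition: if two branches backpropagate closures computing `A @ mat` and `B @ mat`, the accumulated closure computes `(A + B) @ mat`. A small pure-Python illustration of that identity, with nested lists standing in for tensors (`matmul`, `add` and `accumulate` are illustrative helpers, not BackPACK functions):

```python
from typing import Callable, List

Matrix = List[List[float]]
MatProd = Callable[[Matrix], Matrix]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def add(a: Matrix, b: Matrix) -> Matrix:
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def accumulate(existing: MatProd, other: MatProd) -> MatProd:
    # Same pattern as accumulate_backpropagated_quantities: add the two products
    return lambda mat: add(existing(mat), other(mat))


A = [[1.0, 2.0], [3.0, 4.0]]
B = [[0.5, 0.0], [0.0, 0.5]]
mat = [[1.0, 0.0], [2.0, 1.0]]

# Closures standing in for the quantities backpropagated along two branches
combined = accumulate(lambda m: matmul(A, m), lambda m: matmul(B, m))
```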

SumModule extension for GGNMP:

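```python
from backpack.core.derivatives.sum_module import SumModuleDerivatives
from backpack.extensions.curvmatprod.ggnmp.ggnmpbase import GGNMPBase


class GGNMPSumModule(GGNMPBase):
    """GGNMP extension for SumModule."""

    def __init__(self):
        """Initialization."""
        super().__init__(derivatives=SumModuleDerivatives())
```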

_jac_mat_prod for SumModule (in SumModuleDerivatives):

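```python
from typing import List, Tuple

from torch import Tensor

from backpack.custom_module.branching import SumModule


# Defined as a method on SumModuleDerivatives
def _jac_mat_prod(
    self,
    module: SumModule,
    g_inp: Tuple[Tensor],
    g_out: Tuple[Tensor],
    mat: Tensor,
    subsampling: List[int] = None,
) -> Tensor:
    # The output is the sum of the inputs, so the Jacobian w.r.t. each input is the identity
    return mat
```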
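Returning `mat` unchanged is justified because a module that outputs `x1 + x2` has the identity as its Jacobian with respect to each input. A quick numerical check with `torch.autograd.functional.jacobian`, independent of BackPACK; `sum_of_inputs` is a stand-in for SumModule's forward pass, assumed to add its inputs elementwise:

```python
import torch
from torch.autograd.functional import jacobian


def sum_of_inputs(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    # Stand-in for SumModule's forward pass: elementwise sum of two branch outputs
    return x1 + x2


torch.manual_seed(0)
x1 = torch.randn(3)
x2 = torch.randn(3)

# One Jacobian per input, each of shape (3, 3)
jac_x1, jac_x2 = jacobian(sum_of_inputs, (x1, x2))

# Both are the identity, so multiplying any mat by them returns mat
mat = torch.randn(3, 5)
```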

Are you able to help?

f-dangel commented 4 months ago

Hi, thanks for the update. Your changes look good to me. What is the error you're currently seeing? I don't understand what you mean by 'gradients are vanishing'.

Best, Felix

satnavpt commented 4 months ago

I am using GGNMP alongside the implementation of a conjugate gradient optimiser provided as an example here. When printing the gradients (p.grad) at each optimisation step, I see that they become all zeros after a number of steps. This may be due to the modified resnet I am testing (I have disabled average pooling and BatchNorm for the time being, since I only wanted to check whether the SumModule implementation was correct).
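To pin down the step at which the gradients vanish, one option is to log which parameters have a missing or all-zero `.grad` at every step. A minimal sketch; the helper `params_with_zero_grad` is hypothetical and the tiny linear model only stands in for the modified resnet:

```python
from typing import List

import torch
from torch import nn


def params_with_zero_grad(model: nn.Module) -> List[str]:
    """Return names of parameters whose gradient is missing or identically zero."""
    names = []
    for name, p in model.named_parameters():
        if p.grad is None or not torch.any(p.grad != 0):
            names.append(name)
    return names


# Usage: call after loss.backward() and before optimizer.step() in the training loop
torch.manual_seed(0)
model = nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).sum()
loss.backward()
print(params_with_zero_grad(model))
```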

Thanks, Pranav

f-dangel commented 4 months ago

I'm not sure debugging the correctness of GGNMP through the CG optimizer is the most direct approach. You could try comparing GGNMP with BackPACK's hessianfree.ggn_vector_product_from_plist to check whether the matrix-vector product with the GGN works properly. The zero gradients could also come from another effect unrelated to GGNMP.
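Both GGNMP and a GGN-vector product compute G v = J^T H J v, where J is the Jacobian of the network output w.r.t. the parameters and H is the Hessian of the loss w.r.t. that output. The sketch below checks a matrix-free autograd implementation of that product against an explicitly formed GGN on a tiny two-layer network. The network, data, and helper names (`forward`, `ggnvp`, `ggn_explicit`) are made up for illustration; in practice one would put the GGNMP result (or the output of `ggn_vector_product_from_plist`) in place of `ggnvp` and compare in the same way.

```python
import torch
from torch.autograd.functional import hessian, jacobian

torch.manual_seed(0)
torch.set_default_dtype(torch.float64)  # double precision for a tight comparison

# Tiny two-layer network with explicit parameter tensors (illustrative only)
N, D_in, D_hidden, C = 5, 3, 4, 2
X = torch.randn(N, D_in)
y = torch.randint(0, C, (N,))
W1 = torch.randn(D_hidden, D_in, requires_grad=True)
W2 = torch.randn(C, D_hidden, requires_grad=True)
params = [W1, W2]


def forward(w1, w2):
    return torch.tanh(X @ w1.T) @ w2.T


def loss_fn(output):
    return torch.nn.functional.cross_entropy(output, y)


def ggnvp(v):
    """Matrix-free GGN-vector product G v = J^T H J v, using autograd only."""
    output = forward(*params)
    # J v via the double-backward trick: u -> J^T u is linear, so d<J^T u, v>/du = J v
    u = torch.zeros_like(output, requires_grad=True)
    JTu = torch.autograd.grad(output, params, grad_outputs=u, create_graph=True)
    Jv = torch.autograd.grad(JTu, u, grad_outputs=v, retain_graph=True)[0]
    # H (J v), with H the Hessian of the loss w.r.t. the network output
    grad_out = torch.autograd.grad(loss_fn(output), output, create_graph=True)[0]
    HJv = torch.autograd.grad(grad_out, output, grad_outputs=Jv, retain_graph=True)[0]
    # J^T (H J v)
    return torch.autograd.grad(output, params, grad_outputs=HJv)


def ggn_explicit():
    """Form G = J^T H J explicitly; only feasible for tiny models."""
    output = forward(*params).detach()
    J = jacobian(forward, tuple(params))  # one (N, C, *param.shape) block per parameter
    J = torch.cat([j.reshape(N * C, -1) for j in J], dim=1)
    H = hessian(loss_fn, output).reshape(N * C, N * C)
    return J.T @ H @ J


v = [torch.randn_like(p) for p in params]
Gv_matrix_free = torch.cat([g.flatten() for g in ggnvp(v)])
Gv_explicit = ggn_explicit() @ torch.cat([vi.flatten() for vi in v])
```

If the two products agree on a small model but CG still produces zero gradients, the cause is more likely in the optimiser loop or the model than in the GGN product itself.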