Open satnavpt opened 9 months ago
Unsure if this is similar to a previous issue where there was simply a missing link, or whether there is a more fundamental reason why these, and other modules which the documentation claims are supported, are not actually supported.
Hi, thanks for bringing up this documentation inconsistency. You are right that the docs claim more general support than is actually implemented, although it is true that most of the second-order extensions support the mentioned layers.
If you would like to see support for a specific layer and quantity that is currently missing, please feel free to specify.
Hi, thanks for getting back quickly. I am looking to experiment with KFLR and a conjugate gradient optimiser (using GGNMP) on a resnet18 model. I am fine running in eval mode, but neither of these extensions has definitions for BatchNorm and AdaptiveAvgPool right now, which results in an error. Support for these would be greatly appreciated.
I tried manually making some changes, but I see that the issue is more than just missing links to the module extensions...
Thanks for boiling things down. I think what you are requesting requires a lot of new additions to BackPACK. They are not impossible to realize, but you would mostly be on your own in implementing them.

ResNets sum the outputs of nn.Modules in their skip connections, which means you will need to add support for BackPACK's custom SumModule. For KFLR, the required functionality is already implemented, so the only thing you would have to do is write an HBPSumModule extension that uses the SumModuleDerivatives from the core package. For GGNMP, you will also have to implement new core functionality, namely SumModuleDerivatives._jac_mat_prod, and then write the associated module extension.
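To make this concrete, a minimal sketch of such a parameter-free module extension (the KFLR/HBP one) could look as follows. The import paths and base-class name are assumptions modelled on how other layers are wired up, so treat it as a starting point rather than a verified implementation:

# Hypothetical sketch of a KFLR/HBP module extension for SumModule.
# Import paths and class names are assumptions and may need adjusting.
from backpack.core.derivatives.sum_module import SumModuleDerivatives
from backpack.extensions.secondorder.hbp.hbpbase import HBPBaseModule


class HBPSumModule(HBPBaseModule):
    """KFLR/HBP extension for SumModule (no trainable parameters)."""

    def __init__(self):
        super().__init__(derivatives=SumModuleDerivatives())

The new extension would then still have to be registered in the KFLR/HBP extension's module-to-extension mapping (an entry along the lines of SumModule: HBPSumModule()), and analogously for GGNMP once its module extension exists.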
For AdaptiveAvgPool, I verified that the core functionality is already there, so all you would have to do to add support is write module extensions for GGNMP and KFLR that use the AdaptiveAvgPoolDerivatives from the core.
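As an illustration of that pattern, the GGNMP extension could look roughly like the sketch below. The derivatives class name AdaptiveAvgPoolDerivatives is taken from above and the import paths are assumptions; the KFLR counterpart would subclass the HBP base class in the same way.

# Hypothetical sketch of a GGNMP extension for AdaptiveAvgPool2d.
# Import paths and the derivatives class name are assumptions.
from backpack.core.derivatives.adaptive_avg_pool_nd import AdaptiveAvgPoolDerivatives
from backpack.extensions.curvmatprod.ggnmp.ggnmpbase import GGNMPBase


class GGNMPAdaptiveAvgPool2d(GGNMPBase):
    """GGNMP extension for nn.AdaptiveAvgPool2d (no trainable parameters)."""

    def __init__(self):
        super().__init__(derivatives=AdaptiveAvgPoolDerivatives())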
For BatchNorm2d, the situation is similar to AdaptiveAvgPool in that most of the low-level functionality is already implemented and you have to write the corresponding module extensions. Since BatchNorm2d also has trainable parameters, you will have to specify how to compute KFLR for .weight and .bias yourself.
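A possible skeleton for that extension is sketched below; the import paths, the base class, and the parameter-method signatures are assumptions modelled on other BackPACK module extensions, and the actual Kronecker factors for weight and bias are deliberately left unimplemented.

# Hypothetical skeleton of a KFLR/HBP extension for BatchNorm in eval mode.
# Import paths, class names, and method signatures are assumptions.
from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives
from backpack.extensions.secondorder.hbp.hbpbase import HBPBaseModule


class HBPBatchNormNd(HBPBaseModule):
    """KFLR/HBP extension for BatchNorm layers (evaluation mode)."""

    def __init__(self):
        super().__init__(
            derivatives=BatchNormNdDerivatives(), params=["weight", "bias"]
        )

    def weight(self, ext, module, g_inp, g_out, backproped):
        # Kronecker factors for .weight still need to be derived.
        raise NotImplementedError

    def bias(self, ext, module, g_inp, g_out, backproped):
        # Kronecker factors for .bias still need to be derived.
        raise NotImplementedError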
I've made the following changes to get GGNMP working with the SumModule, but it seems my gradients are vanishing:

Accumulation for GGNMP:
def accumulate_backpropagated_quantities(self, existing: Callable, other: Callable) -> Callable:
    # Quantities arriving from different branches are matrix-product functions;
    # accumulate them by summing their results.
    return lambda mat: existing(mat) + other(mat)
Sum Module for GGNMP:
class GGNMPSumModule(GGNMPBase):
"""GGNMP extension for SumModule."""
def __init__(self):
"""Initialization."""
super().__init__(derivatives=SumModuleDerivatives())
SumModuleDerivatives._jac_mat_prod:
def _jac_mat_prod(
    self,
    module: SumModule,
    g_inp: Tuple[Tensor],
    g_out: Tuple[Tensor],
    mat: Tensor,
    subsampling: List[int] = None,
) -> Tensor:
    # The output of SumModule is the sum of its inputs, so the Jacobian with
    # respect to each input is the identity and mat passes through unchanged.
    return mat
Are you able to help?
Hi, thanks for the update. Your changes look good to me. What is the error you're currently seeing? I don't understand what you mean by 'gradients are vanishing'.
Best, Felix
I am using GGNMP alongside the implementation of a conjugate gradient optimiser provided as an example here. Just printing the gradients (p.grad) at each optimisation step, I see that they become all zeros after a number of steps. It is possible that this is due to the modified ResNet I am testing (I disabled average pooling and BatchNorm for the time being, since I was just checking whether the SumModule implementation was correct).
Thanks, Pranav
I'm not sure that debugging the correctness of GGNMP through the CG optimizer is the most direct way. You could try comparing the GGNMP result with BackPACK's hessianfree.ggn_vector_product_from_plist to check whether the matrix-vector product with the GGN works properly. There could be another effect, unrelated to GGNMP, that is giving you zero gradients.
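For concreteness, a comparison along those lines could look roughly like the sketch below. It uses a toy sequential model for brevity; the extended ResNet18 with the new SumModule extension registered would go in its place. The import path of ggn_vector_product_from_plist and the leading-axis convention of p.ggnmp are assumptions and may need adjusting.

# Sketch (untested): sanity-check GGNMP against BackPACK's autodiff-based
# GGN-vector product.
import torch
from torch import nn

from backpack import backpack, extend
from backpack.extensions import GGNMP
from backpack.hessianfree.ggnvp import ggn_vector_product_from_plist

torch.manual_seed(0)

# Toy model; substitute the extended ResNet18 here to exercise the
# new SumModule code path.
model = extend(nn.Sequential(nn.Linear(10, 8), nn.ReLU(), nn.Linear(8, 3)))
lossfunc = extend(nn.CrossEntropyLoss())

X, y = torch.rand(16, 10), torch.randint(0, 3, (16,))
output = model(X)
loss = lossfunc(output, y)

with backpack(GGNMP()):
    loss.backward(retain_graph=True)

for p in model.parameters():
    v = torch.rand_like(p)

    # Reference: multiply this parameter's GGN block onto v via autodiff
    # (restricting plist to [p] isolates the block that GGNMP works with).
    ref = ggn_vector_product_from_plist(loss, output, [p], [v])[0]

    # Assumption: p.ggnmp expects a leading axis for the number of vectors.
    ggnmp_result = p.ggnmp(v.unsqueeze(0)).squeeze(0)

    print(torch.allclose(ggnmp_result, ref, atol=1e-6, rtol=1e-4))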