Closed yuanjiechen closed 1 year ago
Hi,
you're right that there is only an example for writing first-order extensions in the documentation. This is mainly for simplicity. First-order extensions only require methods that specify how information w.r.t. the parameters is extracted. Second-order extensions must also specify how information is propagated through a layer, i.e. from output to input.
The rough outline would be:
1. Create a module extension `MyModuleDiagH` for the Hessian diagonal of your layer that inherits from `ModuleExtension`.
2. Say your layer has the parameters `weight`, `bias`. Then you need to implement the `MyModuleDiagH.weight` and `MyModuleDiagH.bias` methods similar to the first-order extension example.
3. Implement `MyModuleDiagH.backpropagate`, which specifies how information is propagated through the layer from output to input.
4. Register `MyModuleDiagH` in the `DiagHessian` extension so BackPACK knows which extension to call when it encounters your layer.

The diagonal Hessian extension is one of the more complicated ones in BackPACK. You can find a description of the method in Appendix A.3 of the paper.
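To make the three methods in the outline concrete, here is a minimal stand-alone sketch for the simplest case, a linear layer z = W x + b. It is not BackPACK code and does not use BackPACK's classes or method signatures: `LinearDiagHSketch`, its method arguments, and the helpers `matmul` and `transpose` are hypothetical names chosen for illustration. It assumes the Hessian of the loss w.r.t. the layer output is available in factored form H_z = S Sᵀ. For a linear layer the rules are exact; a layer with a nonzero second derivative needs additional terms, which is where the description in the paper comes in.

```python
# Stand-alone illustration in plain Python. NOT BackPACK's API:
# all class, method, and argument names here are hypothetical.

def matmul(A, B):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    """Transpose a matrix given as a list of rows."""
    return [list(col) for col in zip(*A)]


class LinearDiagHSketch:
    """Hessian-diagonal rules for z = W x + b, given S with H_z = S S^T."""

    def bias(self, S):
        # dz/db is the identity, so diag(H_b) = diag(S S^T): row-wise sum of squares.
        return [sum(s * s for s in row) for row in S]

    def weight(self, S, x):
        # dz_i/dW_ij = x_j, so d^2 L / dW_ij^2 = (H_z)_ii * x_j^2.
        diag_z = self.bias(S)
        return [[d * xj * xj for xj in x] for d in diag_z]

    def backpropagate(self, S, W):
        # H_x = W^T H_z W = (W^T S)(W^T S)^T, so the factor passed
        # on to the layer's input is W^T S.
        return matmul(transpose(W), S)


# Small example: 3 outputs, 2 inputs.
W = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5]]
x = [2.0, -1.0]
S = [[1.0, 0.0], [1.0, 1.0], [0.0, 2.0]]  # factor of the output Hessian

ext = LinearDiagHSketch()
diag_b = ext.bias(S)             # Hessian diagonal w.r.t. b
diag_W = ext.weight(S, x)        # Hessian diagonal w.r.t. W, same shape as W
S_in = ext.backpropagate(S, W)   # factor of the Hessian w.r.t. x

# Sanity check against the full input Hessian W^T (S S^T) W.
H_x_full = matmul(transpose(W), matmul(matmul(S, transpose(S)), W))
assert matmul(S_in, transpose(S_in)) == H_x_full
```

The parameter methods only read the incoming factor, while `backpropagate` produces the factor that the preceding layer receives. That split is the difference from a first-order extension, which only needs the parameter methods.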
It would be helpful if you could provide details on how the layer you're trying to add works, e.g. a code snippet of the forward pass. Depending on that, `.backpropagate` might be more or less complicated to implement.
Best, Felix
Thanks for your detailed help, I will try it!
Thanks for your contribution. How do I write a second-order module extension for Hessian diagonal computation? I need to compute the Hessian diagonal for a self-defined nn module. Is there an example or tutorial I can reference? I can only find the first-order extension example. Thanks!