f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/
MIT License

Second order extension #305

Closed: yuanjiechen closed this issue 1 year ago

yuanjiechen commented 1 year ago

Thanks for your contribution. How do I write a second-order module extension for computing the Hessian diagonal? I need the diagonal for a self-defined nn module. Is there an example or tutorial I can reference? I can only find the first-order extension example. Thanks!

f-dangel commented 1 year ago

Hi,

you're right that there is only an example for writing first-order extensions in the documentation. This is mainly for simplicity. First-order extensions only require methods that specify how information w.r.t. the parameters is extracted. Second-order extensions must also specify how information is propagated through a layer, i.e. from output to input.
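To make the difference concrete, here is a minimal NumPy sketch. It is not BackPACK code, and the layer and all function names are hypothetical. It uses a custom element-wise layer y = w * x and assumes the Hessian of the loss w.r.t. the layer output is available as a symmetric factor S, with H_y = S Sᵀ. The "parameter" step reads S to produce the Hessian diagonal w.r.t. w, which corresponds to what a first-order extension already does with the gradient. The "backpropagate" step turns S into the corresponding factor for the layer input, which is the extra work a second-order extension has to specify.

```python
import numpy as np

# Minimal sketch, NOT BackPACK's API: the layer y = w * x (element-wise) and all
# function names are hypothetical. The loss Hessian w.r.t. the layer output is
# assumed to be given as a symmetric factor S, i.e. H_y = S @ S.T.

def scale_forward(w, x):
    return w * x

def scale_diag_h_param(x, S):
    # "Parameter" step: dy_i/dw_i = x_i and y is linear in w, so
    # H_w = diag(x) @ H_y @ diag(x), whose diagonal is x**2 * diag(S @ S.T).
    return x**2 * np.sum(S**2, axis=1)

def scale_backpropagate(w, S):
    # "Backpropagate" step: y is linear in x with Jacobian diag(w), so
    # H_x = diag(w) @ H_y @ diag(w) = S_x @ S_x.T with S_x = diag(w) @ S.
    return w[:, None] * S

def fd_hessian_diag(f, v, h=1e-3):
    # Central finite differences of a scalar function, only used for checking.
    out = np.empty_like(v)
    for i in range(v.size):
        e = np.zeros_like(v)
        e[i] = h
        out[i] = (f(v + e) - 2 * f(v) + f(v - e)) / h**2
    return out

rng = np.random.default_rng(0)
w, x = rng.normal(size=4), rng.normal(size=4)
A, t = rng.normal(size=(3, 4)), rng.normal(size=3)

def loss(y):
    # L(y) = 0.5 * ||A y - t||^2, hence H_y = A.T @ A and S = A.T works
    return 0.5 * np.sum((A @ y - t) ** 2)

S = A.T
diag_w = scale_diag_h_param(x, S)   # Hessian diagonal w.r.t. the parameter w
S_x = scale_backpropagate(w, S)     # factor passed on to the previous layer
```

Because this layer is linear in both w and x, no second-derivative (residual) terms of the layer itself appear. A nonlinear forward pass would add such terms to the backpropagate step.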

The rough outline would be:

The diagonal Hessian extension is one of the more complicated ones in BackPACK. You can find a description of the method in Appendix A.3 of the paper.
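For a rough feel of why this extension is more involved, here is a small NumPy sketch in the spirit of Hessian backpropagation. It is not BackPACK's implementation, and the network, loss, and function names are hypothetical. For a layer z = f(x), the Hessian w.r.t. the input is Jᵀ H_z J plus a residual term Σ_i (∂L/∂z_i) ∇²_x z_i, which is non-zero only for layers with non-vanishing second derivatives. Since the residual can be indefinite, the sketch stores the Hessian as a sum of signed symmetric factors ±S Sᵀ and splits the residual into its positive and negative parts. The example network is Linear → Sigmoid → Linear with a squared-error loss, and the exact Hessian diagonal w.r.t. the first weight matrix is checked against finite differences.

```python
import numpy as np

# Hypothetical example network (not BackPACK code): z1 = W1 @ x, a1 = sigmoid(z1),
# z2 = W2 @ a1, loss L = 0.5 * ||z2 - t||^2. Hessians are kept as signed factors
# [(sign, S), ...] representing sum(sign * S @ S.T).

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def forward(W1, W2, x):
    z1 = W1 @ x
    a1 = sigmoid(z1)
    z2 = W2 @ a1
    return z1, a1, z2

def loss(W1, W2, x, t):
    return 0.5 * np.sum((forward(W1, W2, x)[2] - t) ** 2)

def diag_hessian_W1(W1, W2, x, t):
    z1, a1, z2 = forward(W1, W2, x)
    # Loss Hessian w.r.t. z2 is the identity: a single factor with sign +1.
    factors = [(1.0, np.eye(z2.size))]
    # Linear layer z2 = W2 @ a1: apply the transposed Jacobian W2.T to each factor.
    # No residual, because the layer is linear in its input.
    factors = [(s, W2.T @ S) for s, S in factors]
    # Sigmoid layer: transposed Jacobian diag(sigmoid'(z1)) ...
    d1 = a1 * (1 - a1)          # sigmoid'(z1)
    d2 = d1 * (1 - 2 * a1)      # sigmoid''(z1)
    factors = [(s, d1[:, None] * S) for s, S in factors]
    # ... plus the residual diag(sigmoid''(z1) * dL/da1), split by sign so that
    # each part has a real square root.
    g_a1 = W2.T @ (z2 - t)
    r = d2 * g_a1
    factors.append((+1.0, np.diag(np.sqrt(np.maximum(r, 0.0)))))
    factors.append((-1.0, np.diag(np.sqrt(np.maximum(-r, 0.0)))))
    # Parameter step for z1 = W1 @ x: dz1_i/dW1_ij = x_j and z1 is linear in W1,
    # so the Hessian diagonal entry (i, j) is x_j**2 * H_z1[i, i].
    diag_H_z1 = sum(s * np.sum(S**2, axis=1) for s, S in factors)
    return np.outer(diag_H_z1, x**2)

def fd_diag_hessian_W1(W1, W2, x, t, h=1e-4):
    # Central finite differences, only used to check the sketch.
    out = np.empty_like(W1)
    L0 = loss(W1, W2, x, t)
    for idx in np.ndindex(W1.shape):
        E = np.zeros_like(W1)
        E[idx] = h
        out[idx] = (loss(W1 + E, W2, x, t) - 2 * L0 + loss(W1 - E, W2, x, t)) / h**2
    return out

rng = np.random.default_rng(0)
W1, W2 = rng.normal(size=(5, 3)), rng.normal(size=(2, 5))
x, t = rng.normal(size=3), rng.normal(size=2)
diag_H = diag_hessian_W1(W1, W2, x, t)
```

The residual factors are where a layer's forward pass matters most: a layer that is linear in its input contributes none, while an element-wise nonlinearity contributes a diagonal one.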

It would be helpful if you could provide details on how the layer you're trying to add works, e.g. a code snippet of its forward pass. Depending on that, the .backpropagate method might be more or less complicated to implement.

Best, Felix

yuanjiechen commented 1 year ago

Thanks for your detailed help, I will try it!