ltatzel / PyTorchHessianFree

PyTorch implementation of the Hessian-free optimizer
BSD 3-Clause "New" or "Revised" License
30 stars, 6 forks

How to plug the Hessian matrix into the loss function in pytorch? #3

Closed: 2020213484 closed this issue 1 year ago

2020213484 commented 1 year ago

In the continual learning method EWC (elastic weight consolidation), the loss function includes the Fisher information matrix, which is the expected Hessian of the negative log-likelihood -log p(y|x) with respect to the model parameters. How can I use the code you provided to plug this matrix into the loss function?

ltatzel commented 1 year ago

Sorry for the late response - I was on vacation ;)

I'm not sure I understand the problem. If you want to use the Hessian of the loss function, you can use the "hessian" option of the optimizer. The Fisher information matrix (FIM) is equivalent to the generalized Gauss-Newton matrix (GGN) in many cases (including softmax cross-entropy and square loss); see e.g. Section 9.2 in "New Insights and Perspectives on the Natural Gradient Method". In those cases, you can use the "ggn" option in the optimizer instead.
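
To make the FIM/GGN equivalence concrete for softmax cross-entropy, here is a minimal numerical sketch in plain Python (standard library only). It does not use this repository's code: `softmax`, `cross_entropy`, `fisher_wrt_logits` and `hessian_wrt_logits` are hypothetical helpers written for this illustration, and the example logits are arbitrary. The sketch checks that the Fisher matrix with respect to the logits matches the Hessian of the cross-entropy loss with respect to the logits. Both the GGN and the FIM with respect to the parameters are obtained by sandwiching that matrix between the network Jacobian and its transpose, so the two coincide whenever this inner matrix does.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, label):
    """Negative log-likelihood of `label` under softmax(logits)."""
    return -math.log(softmax(logits)[label])


def fisher_wrt_logits(logits):
    """Fisher matrix w.r.t. the logits: E_{y ~ p}[g_y g_y^T],
    where g_y = p - onehot(y) is the gradient of -log p(y)."""
    p = softmax(logits)
    n = len(p)
    fisher = [[0.0] * n for _ in range(n)]
    for y in range(n):
        g = [p[i] - (1.0 if i == y else 0.0) for i in range(n)]
        for i in range(n):
            for j in range(n):
                fisher[i][j] += p[y] * g[i] * g[j]
    return fisher


def hessian_wrt_logits(logits, label, eps=1e-4):
    """Central finite-difference Hessian of the cross-entropy w.r.t. the logits."""
    n = len(logits)

    def shifted(i, j, di, dj):
        z = list(logits)
        z[i] += di
        z[j] += dj
        return cross_entropy(z, label)

    hess = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            hess[i][j] = (
                shifted(i, j, eps, eps)
                - shifted(i, j, eps, -eps)
                - shifted(i, j, -eps, eps)
                + shifted(i, j, -eps, -eps)
            ) / (4.0 * eps * eps)
    return hess


# Arbitrary example logits; the Hessian does not depend on the chosen label.
logits = [0.5, -1.2, 2.0]
fisher = fisher_wrt_logits(logits)
hess = hessian_wrt_logits(logits, label=1)

max_diff = max(
    abs(fisher[i][j] - hess[i][j]) for i in range(3) for j in range(3)
)
print(f"max |Fisher - Hessian| over all entries: {max_diff:.2e}")
```

The printed difference should be at the level of finite-difference error. Note that this only covers the inner, logit-space matrix; for losses where that matrix is not the Fisher of the predictive distribution, the GGN and the FIM can differ.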