Closed opooladz closed 10 months ago
Hope all is well,
To be clear, I have the code working when I set M_inv = None in the optimizer step. What I'm interested in is using the preconditioner:
torch.inverse(m_inv + damping * torch.eye(*m_inv.shape).to(device))
I have not been able to get it working (trying it on MNIST) with the Fisher matrix, its diagonal, or the approximation of the Hessian.
Any direction would be greatly appreciated. Thank you in advance.
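For context, a minimal sketch of that damped inverse as a standalone helper (my own rewrite, not the repo's code), using `torch.linalg.solve` rather than forming the inverse explicitly, which is the numerically safer route:

```python
import torch

def damped_inverse(m_inv, damping, device="cpu"):
    # Add damping * I before inverting (Tikhonov regularisation) so an
    # ill-conditioned or singular matrix still has a well-defined inverse.
    n = m_inv.shape[0]
    eye = torch.eye(n, device=device, dtype=m_inv.dtype)
    # Solving (m_inv + damping*I) X = I gives the inverse without calling
    # torch.inverse directly, which is more stable for near-singular input.
    return torch.linalg.solve(m_inv.to(device) + damping * eye, eye)
```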
@opooladz , sorry for the very late reply.
Unfortunately I have no experience with Fisher information matrices and I can't really remember why I ended up implementing them.
Have you read Martens' section Designing a Good Preconditioner? I remember skimming some papers exploring preconditioners for this kind of optimisation problem. It seems to be quite tricky, and they come with hyperparameters that may require you to play a bit with the source code. It could also be that there is a bug in the implementation :|
Have you had any progress on this problem?
To my understanding, the Fisher information matrix just acts as a stand-in for the Hessian matrix.
Yes, I have read through some parts of Martens' book. If you don't pass in a preconditioner, your code won't use line 126 of hessianfree:
m = torch.inverse(m_inv + damping * torch.eye(*m_inv.shape))
I have not been able to make progress. I basically just want a second-order optimizer, such as Newton's method, that works on CNNs and is integrated with PyTorch. Have you been able to run even a vanilla second-order method? Running with conjugate gradient and backtracking would of course be a plus.
I was hoping to get a simple train/test example using the Hessian-free optimizer on MNIST. Eventually I want to try a Hessian for the Levenberg-Marquardt update rule, but for now I am working with the EFM. I am also, for now, using the Hessian-vector product instead of the GGN matrix-vector product. Below is my equivalent of the hf_test file. I also needed to change the code for the EFM slightly.
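In case it helps clarify what I mean by Hessian-vector product: I am assuming the standard double-backward construction (a generic sketch, not the repo's implementation):

```python
import torch

def hessian_vector_product(loss, params, vec):
    # First backward pass with create_graph=True so the gradients
    # themselves remain differentiable.
    grads = torch.autograd.grad(loss, params, create_graph=True)
    flat_grad = torch.cat([g.reshape(-1) for g in grads])
    # Second backward pass: d(grad . v)/dtheta = H v.
    hv = torch.autograd.grad(flat_grad @ vec, params, retain_graph=True)
    return torch.cat([h.reshape(-1) for h in hv])
```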
When I run the following, the loss either explodes and then becomes NaN, or I am told that the matrix is not invertible; in both cases the loss is exploding.
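For the non-invertible case, one common Levenberg-Marquardt-style workaround (sketched here as my own heuristic helper, not anything from the repo) is to increase the damping and retry until the solve succeeds:

```python
import torch

def invert_with_adaptive_damping(m_inv, damping=1e-3, factor=10.0, max_tries=8):
    # If the damped matrix is still singular or the solve produces
    # non-finite values, multiply the damping by `factor` and retry.
    eye = torch.eye(m_inv.shape[0], dtype=m_inv.dtype, device=m_inv.device)
    for _ in range(max_tries):
        try:
            m = torch.linalg.solve(m_inv + damping * eye, eye)
            if torch.isfinite(m).all():
                return m, damping
        except RuntimeError:
            pass  # singular: fall through and increase damping
        damping *= factor
    raise RuntimeError("matrix stayed non-invertible even with heavy damping")
```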
I define the closure inside the for loop so that it does not track the rest of the changes in the other file.
I eventually want to get this working on a CNN as well. So if I can get help on either/both that would be amazing.
Thank you in advance.
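For reference, this is the shape of closure-based training step I have in mind, shown with PyTorch's built-in `torch.optim.LBFGS` on synthetic MNIST-shaped data (the hessianfree optimizer's `step(closure)` interface is assumed to be analogous):

```python
import torch
import torch.nn as nn

# Synthetic batch standing in for MNIST (784 flattened inputs, 10 classes).
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(784, 32), nn.Tanh(), nn.Linear(32, 10))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1, max_iter=5)

x = torch.randn(64, 784)          # fake batch of flattened images
y = torch.randint(0, 10, (64,))   # fake labels

def closure():
    # Defined per batch so it captures only the current (x, y).
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    return loss

before = criterion(model(x), y).item()
optimizer.step(closure)
after = criterion(model(x), y).item()
```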