fmeirinhos / pytorch-hessianfree

PyTorch implementation of Hessian Free optimisation
MIT License
43 stars, 8 forks

Status for PyTorch 1.3 #1

Open: sshkhr opened this issue 4 years ago

sshkhr commented 4 years ago

Hi

Thanks for the excellent repo. I was wondering about the status of this repo for PyTorch v1.3: I want to use the Newton-CG method with line search to reimplement a meta-learning paper. Will your code support parameters from conv layers and similar?

Thanks
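Newton-CG itself does not depend on the layer type: autograd forms Hessian-vector products for convolutional weights exactly as it does for linear ones. Below is a minimal, generic sketch (not this repository's implementation) of one damped Newton-CG step with an Armijo backtracking line search. The helper `newton_cg_step`, its defaults (`damping`, `cg_iters`, `c1`) and the tiny CNN used to exercise it are hypothetical, and it assumes a deterministic `loss_fn` closure (fixed mini-batch, no dropout).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import parameters_to_vector, vector_to_parameters


def newton_cg_step(model, loss_fn, cg_iters=20, cg_tol=1e-10, damping=1e-3,
                   c1=1e-4, max_backtracks=20):
    """One damped Newton-CG step with Armijo backtracking (hypothetical helper)."""
    params = [p for p in model.parameters() if p.requires_grad]
    loss = loss_fn()
    # create_graph=True so the gradient itself can be differentiated again.
    grads = torch.autograd.grad(loss, params, create_graph=True)
    g = torch.cat([gr.reshape(-1) for gr in grads])

    def damped_hv(v):
        # Double backprop gives H v for any differentiable layer (conv, linear, ...).
        hv = torch.autograd.grad(g, params, grad_outputs=v,
                                 retain_graph=True, allow_unused=True)
        flat = [(torch.zeros_like(p) if h is None else h).reshape(-1)
                for h, p in zip(hv, params)]
        return torch.cat(flat) + damping * v

    # Truncated CG on (H + damping * I) x = -g.
    b = -g.detach()
    x = torch.zeros_like(b)
    r, d = b.clone(), b.clone()
    rs = r @ r
    for _ in range(cg_iters):
        Ad = damped_hv(d)
        dAd = d @ Ad
        if dAd <= 0:  # negative curvature: keep the current iterate
            break
        alpha = rs / dAd
        x = x + alpha * d
        r = r - alpha * Ad
        rs_new = r @ r
        if rs_new < cg_tol:  # tolerance on the squared residual norm
            break
        d = r + (rs_new / rs) * d
        rs = rs_new

    slope = (g.detach() @ x).item()  # directional derivative along x
    if slope >= 0:  # CG gave no descent direction: fall back to -g
        x, slope = b, -(b @ b).item()

    # Backtracking line search with the Armijo sufficient-decrease test.
    f0 = loss.item()
    theta0 = parameters_to_vector(params).detach()
    step = 1.0
    with torch.no_grad():
        for _ in range(max_backtracks):
            vector_to_parameters(theta0 + step * x, params)
            f_new = loss_fn().item()
            if f_new <= f0 + c1 * step * slope:
                return f_new
            step *= 0.5
        vector_to_parameters(theta0, params)  # no acceptable step: restore
    return f0


# Usage on a tiny CNN with a fixed (deterministic) mini-batch.
torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(1, 4, 3), nn.Tanh(), nn.Flatten(),
                      nn.Linear(4 * 6 * 6, 3))
xb, yb = torch.randn(16, 1, 8, 8), torch.randint(0, 3, (16,))
loss_fn = lambda: F.cross_entropy(model(xb), yb)
before = loss_fn().item()
for _ in range(5):
    after = newton_cg_step(model, loss_fn)
print(f"loss {before:.4f} -> {after:.4f}")
```

The damping term and the early exit on negative curvature are there because the true Hessian of a non-convex network can be indefinite; the line search then guarantees the loss never increases.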

fmeirinhos commented 4 years ago

Cheers, @sshkhr

I have tested the optimiser with PyTorch 1.3 and it seemed to work fine. It just doesn't support parameter groups, so optimising different parameters with different optimiser hyper-parameters would not work. It seems I forgot to raise an error for those cases, but if you need that functionality, let me know and I'll see whether it can be implemented.

Best
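A guard that raises an error for parameter groups could live in the optimiser's constructor path. The sketch below is hypothetical: the class name `SingleGroupOptimizer` is invented and the Hessian-free `step()` is omitted. It relies only on the standard `torch.optim.Optimizer` behaviour of routing every group through `add_param_group`, so it catches groups passed at construction and groups added later.

```python
import torch
from torch.optim import Optimizer


class SingleGroupOptimizer(Optimizer):
    """Hypothetical base class that rejects per-parameter option groups.

    A concrete Hessian-free optimiser would subclass this and implement step().
    """

    def __init__(self, params, lr=1.0):
        super().__init__(params, defaults=dict(lr=lr))

    def add_param_group(self, param_group):
        # Optimizer.__init__ calls add_param_group once per group, so this check
        # fires both at construction time and for groups added afterwards.
        if len(self.param_groups) > 0:
            raise ValueError(
                f"{type(self).__name__} does not support parameter groups; "
                "pass a single iterable of parameters instead")
        super().add_param_group(param_group)


w1 = torch.nn.Parameter(torch.zeros(3))
w2 = torch.nn.Parameter(torch.zeros(2))

opt = SingleGroupOptimizer([w1, w2])  # plain iterable -> one group: accepted
print(len(opt.param_groups))  # 1

try:  # two groups with different hyper-parameters: rejected
    SingleGroupOptimizer([{"params": [w1]}, {"params": [w2], "lr": 0.1}])
except ValueError as err:
    print("rejected:", err)
```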

sshkhr commented 4 years ago

Thanks for clarifying. I'm not looking to use different optimiser hyper-parameters for different network parameters, so that should be okay for now. I'll try to get it working before this weekend and update you. Thanks again!

opooladz commented 4 years ago

I am also interested in extending this for use with a CNN. @sshkhr, have you had any luck? For example, I want to run it on a CNN with this architecture: https://colab.research.google.com/github/rpi-techfundamentals/fall2018-materials/blob/master/10-deep-learning/04-pytorch-mnist.ipynb#scrollTo=0mB6qGuYiwnX

fmeirinhos commented 4 years ago

The CNN architecture you posted should work just fine, @opooladz. See the hf_test.py file for how to apply the optimiser to a torch.nn.Module.

opooladz commented 4 years ago

Thank you for your reply, @fmeirinhos. Yes, I followed that file and tried to extend it to CNNs, using the Hessian-vector product (Hv) method together with the Fisher diagonal (essentially, I want an implementation of the Levenberg-Marquardt update rule). However, after a few iterations the loss goes to NaN when I use the inverse preconditioner, and if I set M_inv = None, the loss simply blows up. Would it be possible to get another example file with a CNN? I can also send you my code.
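For reference, the Levenberg-Marquardt rule in Martens-style Hessian-free optimisation adapts a damping term λ from the reduction ratio ρ = (f(θ+δ) − f(θ)) / (q(δ) − q(0)), where q is the local quadratic model, and the diagonal preconditioner is damped as well, e.g. M = (diag(F) + λ)^(3/4). Keeping λ inside M guarantees strictly positive diagonal entries; dividing by near-zero Fisher entries is one plausible (unconfirmed) source of NaNs with an inverse preconditioner. The sketch below is generic and hypothetical: `preconditioned_cg` and `lm_adjust` are invented names, the thresholds 1/4 and 3/4 and factors 3/2 and 2/3 follow Martens (2010) as commonly described, and a toy quadratic replaces a real network, with `diag(A)` standing in for a Fisher-diagonal estimate. It is not this repository's `M_inv` handling.

```python
import math
import torch


def preconditioned_cg(Bv, b, m_diag, iters=50, tol=1e-20):
    """Preconditioned CG for B x = b; m_diag is a positive diagonal preconditioner."""
    x = torch.zeros_like(b)
    r = b.clone()
    z = r / m_diag  # apply M^{-1}; m_diag must be strictly positive
    d = z.clone()
    rz = r @ z
    for _ in range(iters):
        Bd = Bv(d)
        dBd = d @ Bd
        if dBd <= 0:  # non-positive curvature: stop early
            break
        alpha = rz / dBd
        x = x + alpha * d
        r = r - alpha * Bd
        if r @ r < tol:
            break
        z = r / m_diag
        rz_new = r @ z
        d = z + (rz_new / rz) * d
        rz = rz_new
    return x


def lm_adjust(lam, actual, predicted, boost=1.5, drop=2.0 / 3.0):
    """Levenberg-Marquardt damping heuristic driven by the reduction ratio rho."""
    # actual = f(theta + delta) - f(theta); predicted = q(delta) - q(0); both < 0 on success.
    if not math.isfinite(actual) or predicted >= 0:
        return lam * boost  # loss blew up or model predicted no decrease: damp more
    rho = actual / predicted
    if rho < 0.25:
        return lam * boost  # quadratic model too optimistic: damp more
    if rho > 0.75:
        return lam * drop  # quadratic model trustworthy: damp less
    return lam


# Toy usage: one damped, preconditioned step on an exact quadratic loss.
torch.manual_seed(0)
n = 5
Q = torch.randn(n, n, dtype=torch.float64)
A = Q @ Q.T + n * torch.eye(n, dtype=torch.float64)  # SPD curvature matrix
c = torch.randn(n, dtype=torch.float64)
f = lambda th: 0.5 * th @ A @ th - c @ th
theta = torch.zeros(n, dtype=torch.float64)
lam = 1.0
g = A @ theta - c
m_diag = (torch.diagonal(A) + lam) ** 0.75  # stand-in for a damped diagonal Fisher
delta = preconditioned_cg(lambda v: A @ v + lam * v, -g, m_diag)
predicted = (g @ delta + 0.5 * delta @ A @ delta).item()  # undamped quadratic model
actual = (f(theta + delta) - f(theta)).item()
new_lam = lm_adjust(lam, actual, predicted)
print(f"rho = {actual / predicted:.3f}, lambda {lam} -> {new_lam:.4f}")
```

On a real network, `actual` comes from re-evaluating the loss and `A @ v` is replaced by a curvature-vector product; when the loss returns NaN or rises, the rule above increases λ instead of accepting the step.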