tfjgeorge / nngeometry

{KFAC,EKFAC,Diagonal,Implicit} Fisher Matrices and finite width NTKs in PyTorch
https://nngeometry.readthedocs.io
MIT License
203 stars 20 forks

Information to include in the FIM() loader argument #55

Closed CharlesVhs closed 1 year ago

CharlesVhs commented 1 year ago

Hi, I'm getting started with your code and I would like to adapt it to a supervised regression ML model. I was wondering what information should be included in the loader argument of FIM(model, loader, representation, n_output, variant='classif_logits', device='cpu', function=None, layer_collection=None). Is it only the new inputs to the model, the outputs, or both?

In your example, you use nngeometry to do continual learning on the MNIST dataset, which is a classification problem. I would like to know whether your code only accepts a ClassIncremental() dataset, or whether it can be adapted to regression problems. If so, how should that be implemented?

Thank you.

tfjgeorge commented 1 year ago

Hi Charles,

FIM accepts any PyTorch DataLoader object (or a subclass of it). Only the inputs are used: the outputs are either sampled, if you use FIM_MonteCarlo, or, if you instead use FIM, computed using the closed-form expressions from [Pascanu and Bengio, 2013]. Please have a look at the following notebook, which uses a vanilla DataLoader class instead of ClassIncremental from the Continuum library: https://github.com/tfjgeorge/nngeometry-examples/blob/main/display_and_timings/Timings%20and%20display%20of%20FIM%20representations.ipynb

In order to compute a FIM for a regression problem, you need to pass variant='regression' instead of variant='classif_logits'.

Please do not hesitate to ask if you have any further questions.

Thomas

[Pascanu and Bengio, 2013] Revisiting Natural Gradient for Deep Networks