tfjgeorge / nngeometry

{KFAC,EKFAC,Diagonal,Implicit} Fisher Matrices and finite width NTKs in PyTorch
https://nngeometry.readthedocs.io
MIT License

Is there a 'regression' option in FIM_MonteCarlo? #50

Closed: zoezhou1999 closed this issue 1 year ago

zoezhou1999 commented 2 years ago

Hi, thank you for your great work, it is amazing! Is there a 'regression' option in FIM_MonteCarlo?

tfjgeorge commented 2 years ago

Hi, that's not really a fix, but I updated the docstring of FIM_MonteCarlo, which incorrectly mentioned a 'regression' variant.

I am not sure what you would expect as a Monte Carlo estimate of the FIM in the case where the output of a neural net is the mean of a Gaussian distribution. Maybe by specifying a scalar variance and sampling from the Gaussian model? If you have any leads, it should not be difficult to implement!
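One way to read the suggestion above (fix a scalar variance, then sample targets from the Gaussian model) is sketched here in plain Python. It assumes a hypothetical linear model f(x) = w·x standing in for the network, so that the Jacobian of the single output is just x; the names `fim_mc_gaussian` and `fim_exact_gaussian` are illustrative and not part of nngeometry. Since the score of N(y | f(x), σ²) is (y - f(x))/σ² · ∇f(x) and E[(y - f(x))²] = σ², the Monte Carlo average should converge to the closed form E[∇f ∇fᵀ]/σ².

```python
import random

# Hypothetical stand-in model: a linear "network" f(x) = w . x whose output is
# the mean of a Gaussian with a fixed, user-chosen scalar variance sigma**2.
def predict(w, x):
    return sum(wi * xi for wi, xi in zip(w, x))

def grad_predict(w, x):
    # For the linear model, df/dw = x (the Jacobian of the single output).
    return list(x)

def add_outer(F, g, scale):
    # In place: F += scale * g g^T.
    for i, gi in enumerate(g):
        for j, gj in enumerate(g):
            F[i][j] += scale * gi * gj

def fim_exact_gaussian(w, xs, sigma):
    # Closed form for a Gaussian likelihood with fixed variance: mean of J^T J / sigma^2.
    d = len(w)
    F = [[0.0] * d for _ in range(d)]
    for x in xs:
        add_outer(F, grad_predict(w, x), 1.0 / (sigma ** 2 * len(xs)))
    return F

def fim_mc_gaussian(w, xs, sigma, n_samples, rng):
    # Monte Carlo estimate: sample targets from the model, average score outer products.
    d = len(w)
    F = [[0.0] * d for _ in range(d)]
    scale = 1.0 / (len(xs) * n_samples)
    for x in xs:
        mu = predict(w, x)
        jac = grad_predict(w, x)
        for _ in range(n_samples):
            y = rng.gauss(mu, sigma)  # drawn from the model, not from dataset labels
            # Score of log N(y | mu, sigma^2) w.r.t. w: (y - mu) / sigma^2 * df/dw.
            coeff = (y - mu) / sigma ** 2
            add_outer(F, [coeff * j for j in jac], scale)
    return F

xs = [[1.0, 0.5], [-0.3, 2.0]]
w = [0.2, -0.1]
exact = fim_exact_gaussian(w, xs, sigma=0.5)
mc = fim_mc_gaussian(w, xs, sigma=0.5, n_samples=20000, rng=random.Random(0))
```

With a fixed scalar variance the exact FIM is already available in closed form, which may be part of why a separate Monte Carlo variant adds little in this setting.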

tfjgeorge commented 1 year ago

I am closing this, as there is no straightforward way of performing an MC estimate of the FIM in the case of a Gaussian model.

fmaaf commented 1 year ago

If I want to compute the FIM (only for the supported layers) of a 3D detection model such as SECOND (code: https://github.com/traveller59/second.pytorch, paper: https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf), do I just need to extend the 'regression' variants with the loss function of SECOND?

tfjgeorge commented 1 year ago

Hi, I am not so sure what you are trying to achieve.

The FIM can be computed if you have a parametrized model that defines a probability distribution. At a quick glance, I cannot find one in the SECOND paper, which I was not familiar with.

In the particular context of KFAC or EKFAC, the FIM is used in natural gradient in order to speed up training. If this is what you want to achieve, you can try the FIM variant that corresponds to your task (i.e. regression or classification). Be aware that the natural gradient is very sensitive to the choice of hyperparameters (learning rate and damping).
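To make the role of damping concrete, here is a minimal natural gradient step, θ ← θ - lr · (F + damping · I)⁻¹ g, in plain Python with a small dense solver. The function names and the toy 2×2 Fisher are hypothetical illustrations; KFAC and EKFAC replace the dense solve with structured approximations, which this sketch does not attempt.

```python
def solve(A, b):
    """Solve A x = b by Gaussian elimination with partial pivoting (small dense A)."""
    n = len(b)
    # Augmented copy, so the caller's matrix and vector are left untouched.
    M = [list(row) + [bi] for row, bi in zip(A, b)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= factor * M[col][c]
    # Back substitution on the upper-triangular system.
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        s = sum(M[r][c] * x[c] for c in range(r + 1, n))
        x[r] = (M[r][n] - s) / M[r][r]
    return x

def natural_gradient_step(theta, grad, fisher, lr, damping):
    # Damped natural gradient: precondition the gradient with (F + damping * I)^-1.
    n = len(theta)
    damped = [[fisher[i][j] + (damping if i == j else 0.0) for j in range(n)]
              for i in range(n)]
    direction = solve(damped, grad)
    return [t - lr * d for t, d in zip(theta, direction)]

fisher = [[2.0, 0.0], [0.0, 8.0]]
grad = [2.0, 8.0]
print(natural_gradient_step([0.0, 0.0], grad, fisher, lr=1.0, damping=0.0))  # → [-1.0, -1.0]
```

With zero damping the preconditioner equalizes the two directions despite their 4× different curvature; as damping grows, the step tends to (lr / damping) · g, i.e. plain gradient descent with a much smaller effective learning rate. Because lr and damping jointly set the step size, they tend to have to be tuned together.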

If, however, your model outputs too many scalar values, you should instead use a Monte Carlo estimate, as computing the exact FIM would otherwise be prohibitively slow.
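The cost argument can be illustrated on a hypothetical linear softmax classifier in plain Python: the exact FIM needs one gradient per output class for every example, while a Monte Carlo estimate draws labels from the model's own predictive distribution and needs one gradient per sample. This is a sketch of the general idea under those assumptions, not nngeometry's implementation; `fim_exact`, `fim_mc` and the toy weights are illustrative.

```python
import math
import random

def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def class_probs(W, x):
    # Hypothetical linear classifier: logits = W x, one row of W per class.
    return softmax([sum(wi * xi for wi, xi in zip(row, x)) for row in W])

def score(p, x, k):
    # d/dW log softmax(W x)[k] = (e_k - p) x^T, flattened row by row.
    return [((1.0 if c == k else 0.0) - p[c]) * xj for c in range(len(p)) for xj in x]

def add_outer(F, g, scale):
    # In place: F += scale * g g^T.
    for i, gi in enumerate(g):
        for j, gj in enumerate(g):
            F[i][j] += scale * gi * gj

def fim_exact(W, xs):
    n = len(W) * len(W[0])
    F = [[0.0] * n for _ in range(n)]
    for x in xs:
        p = class_probs(W, x)
        # One gradient per output: the cost grows linearly with the number of classes.
        for k in range(len(p)):
            add_outer(F, score(p, x, k), p[k] / len(xs))
    return F

def fim_mc(W, xs, n_samples, rng):
    n = len(W) * len(W[0])
    F = [[0.0] * n for _ in range(n)]
    for x in xs:
        p = class_probs(W, x)
        for _ in range(n_samples):
            # Sample a label from the model, independent of the number of classes.
            k = rng.choices(range(len(p)), weights=p)[0]
            add_outer(F, score(p, x, k), 1.0 / (len(xs) * n_samples))
    return F

W = [[0.5, -0.2], [0.1, 0.3], [-0.4, 0.2]]  # 3 classes, 2 input features
xs = [[1.0, 0.5], [-0.5, 1.0]]
exact = fim_exact(W, xs)
approx = fim_mc(W, xs, n_samples=10000, rng=random.Random(0))
```

Many samples are used here only so the estimate visibly matches the exact matrix; in practice a single sample per example is the usual choice when the number of outputs is large, trading variance for cost.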

Another option would be to use the second-moment matrix of the loss gradients (sometimes called the empirical Fisher) instead.
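The difference from the true FIM is that the empirical Fisher plugs in the observed labels instead of labels sampled from the model. A small plain-Python sketch, assuming a hypothetical linear regression model with Gaussian negative log-likelihood (y - w·x)²/(2σ²), makes this visible: at a perfect fit every loss gradient is zero, so the empirical Fisher vanishes while the true Fisher E[x xᵀ]/σ² does not.

```python
def predict(w, x):
    return sum(wi * xi for wi, xi in zip(w, x))

def loss_grad(w, x, y, sigma):
    # Gradient of (y - w.x)^2 / (2 sigma^2) w.r.t. w, using the OBSERVED label y.
    r = y - predict(w, x)
    return [-r * xi / sigma ** 2 for xi in x]

def empirical_fisher(w, xs, ys, sigma):
    # Second-moment matrix of per-example loss gradients.
    d = len(w)
    F = [[0.0] * d for _ in range(d)]
    for x, y in zip(xs, ys):
        g = loss_grad(w, x, y, sigma)
        for i in range(d):
            for j in range(d):
                F[i][j] += g[i] * g[j] / len(xs)
    return F

def true_fisher(xs, sigma):
    # Same model, expectation over y ~ p(y | x): mean of x x^T / sigma^2 (labels unused).
    d = len(xs[0])
    return [[sum(x[i] * x[j] for x in xs) / (len(xs) * sigma ** 2) for j in range(d)]
            for i in range(d)]

w = [1.0, 2.0]
xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
ys = [predict(w, x) for x in xs]  # labels fitted exactly by w
print(empirical_fisher(w, xs, ys, sigma=1.0))  # → [[0.0, 0.0], [0.0, 0.0]]
```

The empirical Fisher is cheaper and needs no sampling, but near a well-fit solution it can differ substantially from the FIM.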

fmaaf commented 1 year ago

Hi, actually I want to compute the FIM to replace the diagonal matrix in EWC, in order to achieve better continual-learning performance.
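For reference, the change being asked about amounts to replacing a weighted sum of squares with a quadratic form: the EWC penalty (λ/2) Σᵢ Fᵢᵢ (θᵢ - θ*ᵢ)² becomes (λ/2) (θ - θ*)ᵀ F (θ - θ*). The plain-Python sketch below uses hypothetical names (`ewc_penalty`, `diagonal_only`) and a toy 2×2 matrix; in practice F would come from whichever FIM representation is chosen (full, KFAC, EKFAC or diagonal), which is not modeled here.

```python
def ewc_penalty(theta, theta_star, fisher, lam):
    # lam / 2 * (theta - theta*)^T F (theta - theta*); lam is the EWC strength.
    d = [t - s for t, s in zip(theta, theta_star)]
    n = len(d)
    quad = sum(d[i] * fisher[i][j] * d[j] for i in range(n) for j in range(n))
    return 0.5 * lam * quad

def diagonal_only(fisher):
    # Standard EWC keeps only the diagonal of the Fisher.
    n = len(fisher)
    return [[fisher[i][j] if i == j else 0.0 for j in range(n)] for i in range(n)]

F = [[2.0, 1.0], [1.0, 2.0]]
theta_star = [0.0, 0.0]
theta = [1.0, -1.0]
print(ewc_penalty(theta, theta_star, F, lam=1.0))                 # → 1.0
print(ewc_penalty(theta, theta_star, diagonal_only(F), lam=1.0))  # → 2.0
```

Here the off-diagonal terms halve the penalty for moving along a direction the full Fisher considers cheap, which is the kind of correlation a diagonal approximation cannot capture.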

tfjgeorge commented 1 year ago

Then the same argument as above: