tfjgeorge / nngeometry

{KFAC,EKFAC,Diagonal,Implicit} Fisher Matrices and finite width NTKs in PyTorch
https://nngeometry.readthedocs.io
MIT License

compute FIM of partial parameters #69

Closed: JiaxiangRen closed this issue 11 months ago

JiaxiangRen commented 12 months ago

First, thanks for the amazing work! I want to compute the FIM for only a subset of the parameters, i.e. only part of the model's parameters require gradients. Is that possible?

tfjgeorge commented 12 months ago

Yes, sure!

You need to specify a LayerCollection object, which describes the part of the parameter space that you want the FIM to cover. By default, the FIM helpers create a LayerCollection that includes all parameters of the model, but you can instead instantiate one manually. See for instance https://github.com/tfjgeorge/nngeometry/blob/78ba46c541a2656b80e348e366c36de47f54baa9/tests/tasks.py#L456, where the LayerCollection comprises only a single layer.

Then pass that LayerCollection object to FIM through its layer_collection argument.