tfjgeorge / nngeometry

{KFAC,EKFAC,Diagonal,Implicit} Fisher Matrices and finite width NTKs in PyTorch
https://nngeometry.readthedocs.io
MIT License

[suggestion] Ignore unsupported modules option #82

Closed fredguth closed 3 months ago

fredguth commented 3 months ago

Is it possible to just ignore unsupported modules (like nn.BatchNorm2d)? I want to use nngeometry with timm models.

tfjgeorge commented 3 months ago

Hi, thanks for reaching out.

BatchNorm2d is actually supported with PMatDiag, PMatBlockDiag, or PMatDense. I am guessing that you meant to use it with (E)KFAC? The issue is that KFAC does not really make sense for these layers, and there is little to gain anyway, since they typically have far fewer parameters than convolutional or fully connected layers. In that case, the simplest way to ignore some layers is to manually create a LayerCollection and add only the relevant layers, skipping those for which KFAC doesn't make sense: https://nngeometry.readthedocs.io/en/latest/api/layercollection.html
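A minimal sketch of that approach, assuming KFAC in nngeometry handles `Linear` and `Conv2d` layers (check the docs for the version you have installed). `select_kfac_layers` and `build_kfac_layer_collection` are hypothetical helper names. `LayerCollection` and its `add_layer_from_model(model, module)` method are taken from the LayerCollection API page linked above, but verify the signature against your version. The selection helper relies only on `model.named_modules()`, so it runs without nngeometry.

```python
# Assumption: these are the layer types KFAC supports in your nngeometry
# version; adjust the tuple if that is not the case.
KFAC_LAYER_TYPES = ("Linear", "Conv2d")


def select_kfac_layers(model, layer_types=KFAC_LAYER_TYPES):
    """Return the (name, module) pairs of the layers KFAC should cover.

    Matches on the exact class name, so BatchNorm2d, LayerNorm, containers
    such as Sequential, and custom subclasses are all skipped.
    """
    return [
        (name, module)
        for name, module in model.named_modules()
        if type(module).__name__ in layer_types
    ]


def build_kfac_layer_collection(model):
    """Hypothetical helper: a LayerCollection holding only KFAC-friendly layers."""
    # Imported lazily so select_kfac_layers can be used without nngeometry.
    from nngeometry.layercollection import LayerCollection

    lc = LayerCollection()
    for _, module in select_kfac_layers(model):
        lc.add_layer_from_model(model, module)
    return lc
```

With a timm model, the returned collection would then be handed to nngeometry together with the model, for instance through a `layer_collection` argument of `FIM` (argument name assumed here; check the docs for your version).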

I chose to raise an Exception instead of silently skipping these layers, so as not to mislead users into believing that they are computing the FIM for all layers.

In the future, when I have time, I plan to implement a mixed representation: KFAC for supported layers, and PMatDiag or PMatBlockDiag for the others.
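Until such a mixed representation exists, a user could approximate it by splitting the layers into two groups, building one LayerCollection per group as suggested above, and using PMatKFAC on one and PMatDiag on the other. The sketch below only does the split. It assumes KFAC covers `Linear` and `Conv2d`, and the names `split_for_mixed_representation` and `has_own_parameters` are hypothetical. Treating the two groups as independent blocks (dropping the cross terms between them) is also an assumption, though it is in the same spirit as KFAC, which is already block-diagonal across layers.

```python
# Assumption: the layer types KFAC supports; everything else that owns
# parameters goes to a diagonal / block-diagonal fallback.
KFAC_LAYER_TYPES = ("Linear", "Conv2d")


def has_own_parameters(module):
    """True if the module owns parameters directly, not through its children."""
    return any(True for _ in module.parameters(recurse=False))


def split_for_mixed_representation(model, kfac_types=KFAC_LAYER_TYPES):
    """Partition parametrized modules into (kfac_layers, fallback_layers) by name."""
    kfac_layers, fallback_layers = [], []
    for name, module in model.named_modules():
        if type(module).__name__ in kfac_types:
            kfac_layers.append(name)
        elif has_own_parameters(module):
            # e.g. BatchNorm2d, or a root module holding raw nn.Parameters.
            fallback_layers.append(name)
    return kfac_layers, fallback_layers
```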

Hope this helps!

fredguth commented 3 months ago

I can use a layer_collection instead of a model? I didn't know that.
Thanks!