tfjgeorge / nngeometry

{KFAC,EKFAC,Diagonal,Implicit} Fisher Matrices and finite width NTKs in PyTorch
https://nngeometry.readthedocs.io
MIT License

Adding support for Conv1d layers #73

Closed: Xuzzo closed this issue 10 months ago

Xuzzo commented 11 months ago

Hello and thanks again for your work.

Would it be possible to add support for Conv1d layers in K-FAC?

Thanks
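For readers unfamiliar with how K-FAC carries over to convolutions, here is a minimal sketch in plain PyTorch. It is not nngeometry's code (PR #76 below is the authoritative implementation). It assumes the usual convolutional K-FAC treatment, where every output position is handled as an extra sample: the A factor is the second moment of the input patches (extracted with `F.unfold` on a height-1 view of the signal), and the G factor is the second moment of the per-position output gradients. The helper names `unfold_conv1d_input` and `kfac_factors_conv1d` are hypothetical, the sketch assumes integer padding and `groups=1`, and implementations differ in how they scale by the number of positions. Here both factors are plain averages, which only changes an overall constant.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def unfold_conv1d_input(x: torch.Tensor, conv: nn.Conv1d) -> torch.Tensor:
    """Hypothetical helper: the input patch seen by each output position of `conv`.

    Assumes integer padding and groups=1. x has shape (batch, in_channels, length);
    the result has shape (batch * out_length, in_channels * kernel_size).
    """
    # View the 1D signal as a height-1 image so F.unfold can extract sliding windows.
    patches = F.unfold(
        x.unsqueeze(2),
        kernel_size=(1, conv.kernel_size[0]),
        dilation=(1, conv.dilation[0]),
        padding=(0, conv.padding[0]),
        stride=(1, conv.stride[0]),
    )  # (batch, in_channels * kernel_size, out_length)
    return patches.transpose(1, 2).reshape(-1, patches.size(1))


def kfac_factors_conv1d(conv: nn.Conv1d, x: torch.Tensor, grad_out: torch.Tensor):
    """Hypothetical helper: K-FAC factors (A, G) for a Conv1d layer.

    grad_out is dLoss/d(output) with shape (batch, out_channels, out_length).
    """
    a = unfold_conv1d_input(x, conv)
    if conv.bias is not None:
        # Append a constant 1 so the bias shares the A factor with the weights.
        a = torch.cat([a, a.new_ones(a.size(0), 1)], dim=1)
    # One row per (example, position) pair, matching the rows of `a`.
    g = grad_out.transpose(1, 2).reshape(-1, grad_out.size(1))
    A = a.T @ a / a.size(0)  # second moment of input patches
    G = g.T @ g / g.size(0)  # second moment of output gradients
    return A, G


# Toy usage: one Conv1d layer plus global average pooling as a 4-class classifier.
torch.manual_seed(0)
conv = nn.Conv1d(3, 4, kernel_size=5, padding=2)
x = torch.randn(8, 3, 32)
out = conv(x)
logits = out.mean(dim=2)
# Labels sampled from the model itself, so G targets the true Fisher, not the empirical one.
y = torch.distributions.Categorical(logits=logits).sample()
loss = F.cross_entropy(logits, y, reduction="sum")
grad_out = torch.autograd.grad(loss, out)[0]

A, G = kfac_factors_conv1d(conv, x, grad_out)
print(A.shape, G.shape)  # torch.Size([16, 16]) torch.Size([4, 4])
```

With the bias folded in, the layer's weight gradient reshaped to `(4, 16)` equals `g.T @ a`, so it can be preconditioned as `G^-1 @ grad @ A^-1` (with damping), which is what the factorization makes cheap.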

tfjgeorge commented 10 months ago

Hi, this PR: https://github.com/tfjgeorge/nngeometry/pull/76 should do it.

I also implemented EKFAC, if you are interested. It gives a more accurate FIM at a slightly higher computational cost, and it requires an extra call to M_ekfac.update_diag(loader) after your FIM is instantiated.
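To illustrate why EKFAC is more accurate and why it needs an extra pass, here is a minimal sketch in plain PyTorch. It assumes the standard EKFAC idea (eigenvalue-corrected K-FAC): keep the eigenbasis of the K-FAC factors, but re-estimate the diagonal scaling in that basis from projected per-example gradients instead of using products of factor eigenvalues. The extra `update_diag(loader)` call presumably performs a pass of this kind, but the code below is not nngeometry's API or implementation. To keep per-example gradients exact and simple, it uses a bias-free `nn.Linear` layer rather than Conv1d, and `precondition` is a hypothetical helper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy linear classifier in float64 so numerical comparisons stay tight.
layer = nn.Linear(6, 3, bias=False).double()
x = torch.randn(64, 6, dtype=torch.float64)
out = layer(x)
# Labels sampled from the model, as for the true (not empirical) Fisher.
y = torch.distributions.Categorical(logits=out).sample()
loss = F.cross_entropy(out, y, reduction="sum")
g = torch.autograd.grad(loss, out)[0]  # per-example output gradients, (N, 3)
a = x                                   # per-example layer inputs, (N, 6)

# K-FAC: Fisher ~ kron(G, A), with A = E[a a^T] and G = E[g g^T].
A = a.T @ a / a.size(0)
G = g.T @ g / g.size(0)
eval_A, U_A = torch.linalg.eigh(A)
eval_G, U_G = torch.linalg.eigh(G)
# In the Kronecker eigenbasis, K-FAC's diagonal is the product of factor eigenvalues.
kfac_diag = torch.outer(eval_G, eval_A)  # (3, 6)

# EKFAC: same eigenbasis, diagonal re-fitted from projected per-example gradients.
per_example_grads = g.unsqueeze(2) * a.unsqueeze(1)  # (N, 3, 6), each is g_n a_n^T
projected = U_G.T @ per_example_grads @ U_A          # rotate each gradient into the basis
ekfac_diag = (projected ** 2).mean(dim=0)            # (3, 6), one extra pass over the data


def precondition(grad_w, diag, damping=1e-3):
    """Hypothetical helper: apply the damped inverse of U diag U^T to a weight gradient."""
    return U_G @ ((U_G.T @ grad_w @ U_A) / (diag + damping)) @ U_A.T


mean_grad = per_example_grads.mean(dim=0)
step_kfac = precondition(mean_grad, kfac_diag)
step_ekfac = precondition(mean_grad, ekfac_diag)
```

Because the EKFAC diagonal is exactly the diagonal of the full Fisher expressed in the K-FAC eigenbasis, it is the best diagonal for that basis in Frobenius norm, so it can only match or improve on K-FAC's eigenvalue products. The price is the extra pass over the data to compute `ekfac_diag`.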

Xuzzo commented 10 months ago

Hi, thank you very much for the help! Closing the issue.