tfjgeorge / nngeometry

{KFAC,EKFAC,Diagonal,Implicit} Fisher Matrices and finite width NTKs in PyTorch
https://nngeometry.readthedocs.io
MIT License

How is it possible to handle a model that has a BatchNorm layer using the PMatEKFAC representation to get the FIM? #19

Closed johnrachwan123 closed 3 years ago

johnrachwan123 commented 3 years ago

Currently, I get a NotImplementedError exception.

tfjgeorge commented 3 years ago

Hi, you cannot directly use KFAC/EKFAC for batch norm; these approximations make no sense for batch norm layers. Depending on what you want to do, you can instead use a PMatBlockDiag representation for the batch norm layers, since they have many fewer parameters, and still use EKFAC for linear or conv layers. If you tell me in what context you are using NNGeometry, I can guide you through how to do this.

johnrachwan123 commented 3 years ago

I have a ResNet18 model for which I would like to get the trace of the FIM. However, this model has many BatchNorm layers, and using the standard PMatDiag representation runs my 8 GB CUDA GPU out of memory. How would you recommend I solve this memory issue?

tfjgeorge commented 3 years ago

In that case I suggest you reduce the batch size of the data loader you use for computing the FIM.
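
For example, a minimal sketch of what that looks like. The FIM call follows the README example; the toy model, dataset, batch size and n_output are placeholders for your own ResNet18 setup, and exact argument names may vary across NNGeometry versions:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from nngeometry.metrics import FIM
from nngeometry.object import PMatDiag

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Toy stand-ins for your ResNet18 and training set.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device)
dataset = TensorDataset(torch.randn(256, 3, 32, 32),
                        torch.randint(0, 10, (256,)))

# A smaller batch size only lowers peak GPU memory during the FIM
# computation; the estimate still averages over the whole dataset.
fim_loader = DataLoader(dataset, batch_size=32, shuffle=False)

F_diag = FIM(model=model,
             loader=fim_loader,
             representation=PMatDiag,
             n_output=10,
             device=device)
print(F_diag.trace())
```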

johnrachwan123 commented 3 years ago

Is there an easy way to modify the code in order to approximate the linear and conv layers using PMatEKFAC and the BatchNorm layers using PMatBlockDiag?

johnrachwan123 commented 3 years ago

I think I figured it out, but I have a small additional question: what would you recommend as the fastest way to get the trace of the FIM? I am only interested in the trace.

tfjgeorge commented 3 years ago

There is an example of PMatBlockDiag for Batch norm layers and KFAC for other layers here: https://github.com/tfjgeorge/nngeometry/blob/master/examples/FIM%20for%20EWC.ipynb (scroll down to "KFAC and Batch norm layers")
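
If it helps, here is a rough sketch of that idea, reusing model, fim_loader and device from the earlier sketch. The LayerCollection method names and the layer_collection argument below are a guess; the notebook has the exact code:

```python
from torch import nn
from nngeometry.metrics import FIM
from nngeometry.object import PMatEKFAC, PMatBlockDiag
from nngeometry.layercollection import LayerCollection

# Split the parameters into two groups: Linear/Conv2d layers for EKFAC,
# BatchNorm layers for a block-diagonal representation.
# NOTE: add_layer_from_model and layer_collection= are assumed names;
# see the linked notebook for the exact calls.
lc_main, lc_bn = LayerCollection(), LayerCollection()
for module in model.modules():
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        lc_main.add_layer_from_model(model, module)
    elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
        lc_bn.add_layer_from_model(model, module)

F_main = FIM(model=model, loader=fim_loader, representation=PMatEKFAC,
             n_output=10, device=device, layer_collection=lc_main)
F_bn = FIM(model=model, loader=fim_loader, representation=PMatBlockDiag,
           n_output=10, device=device, layer_collection=lc_bn)

# The two parameter groups are disjoint, so the trace of the full FIM
# is the sum of the two traces.
total_trace = F_main.trace() + F_bn.trace()
```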

However, if you only need the trace, I suggest you either use a PMatDiag or a PMatImplicit representation. In either case, if you get a memory error, just reduce the batch size.
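
For the trace-only case, the same call with PMatImplicit swapped in looks roughly like this (PMatImplicit never builds the matrix in memory):

```python
from nngeometry.metrics import FIM
from nngeometry.object import PMatImplicit

# Same model and loader as in the earlier sketch; only the representation
# changes. If this still runs out of GPU memory, lower the batch size of
# fim_loader.
F_implicit = FIM(model=model,
                 loader=fim_loader,
                 representation=PMatImplicit,
                 n_output=10,
                 device=device)
print(F_implicit.trace())
```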

johnrachwan123 commented 3 years ago

Yes, I thought of using those, but the issue is that they become extremely slow. My purpose for calculating the trace is to detect when the forgetting phase of network training is reached (from https://arxiv.org/abs/1711.08856). But if detecting this phase is very expensive, it defeats the purpose (since I use this phase to improve something in network training).

tfjgeorge commented 3 years ago

I wouldn't expect KFAC or EKFAC to be any faster than PMatImplicit or PMatDiag, though.

There are, however, two things you can do for a speedup. The first is to estimate the FIM on only a subset of your training set.

Best

johnrachwan123 commented 3 years ago

Thanks a lot for your help! I actually tried using only a subset of the training set, but it seemed like the value of the trace became a lot bigger, so I might have made a small mistake. All I did was break the loop that was going through the data loader and adjust n_examples to match the number of iterations performed before the break. Is there anything else I should do?

tfjgeorge commented 3 years ago

The easiest way of using a subset without having to hack the internals of NNGeometry would be to create another DataLoader from your training set using torch.utils.data.Subset.
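
Roughly, assuming train_dataset and model are your own training set and network (the subset size and batch size are arbitrary choices):

```python
import torch
from torch.utils.data import DataLoader, Subset
from nngeometry.metrics import FIM
from nngeometry.object import PMatImplicit

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Pick a fixed random subset of the training set (here 5000 examples)
# and build a fresh DataLoader from it; the number of examples is then
# taken from this loader, so no NNGeometry internals need patching.
indices = torch.randperm(len(train_dataset))[:5000].tolist()
subset_loader = DataLoader(Subset(train_dataset, indices),
                           batch_size=32, shuffle=False)

F_subset = FIM(model=model,
               loader=subset_loader,
               representation=PMatImplicit,
               n_output=10,
               device=device)
print(F_subset.trace())
```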

johnrachwan123 commented 3 years ago

Thanks a lot for your help and thank you for this great library!