GMvandeVen / continual-learning

PyTorch implementation of various methods for continual learning (XdG, EWC, SI, LwF, FROMP, DGR, BI-R, ER, A-GEM, iCaRL, Generative Classifier) in three different scenarios.
MIT License

Empirical Fisher Estimation #11

Closed: tangbinh closed this issue 4 years ago.

tangbinh commented 4 years ago

Averaging the gradients over samples by calling F.nll_loss (whose default reduction averages over the batch) before squaring them seems convenient, because only one backward pass is needed. However, I think the diagonal of the empirical Fisher information matrix should be calculated by squaring the gradients before averaging them (as done in this TensorFlow implementation). Can you please confirm whether the order matters here?

My understanding is that the expected value of the gradient of the log-likelihood (the score) is 0 (see this Wikipedia article), so if you average first, the resulting Fisher values are very close to 0, which seems incorrect. Am I missing something here? Please let me know what you think. Thank you.
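To see why the order matters, here is a minimal framework-free sketch. It assumes a hypothetical one-parameter logistic model p(y=1|x) = sigmoid(w * x) with labels drawn from the model itself, so the expected per-sample gradient is exactly 0. The weight `w = 1.5`, the Gaussian inputs and the sample count are illustrative choices, not taken from the repository.

```python
import math
import random

random.seed(0)

# Hypothetical 1-D logistic model p(y=1|x) = sigmoid(w * x) with a fixed weight.
w = 1.5


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


# Per-sample gradient of log p(y|x) with respect to w is (y - sigmoid(w * x)) * x.
grads = []
for _ in range(10_000):
    x = random.gauss(0.0, 1.0)
    p = sigmoid(w * x)
    y = 1 if random.random() < p else 0  # label sampled from the model itself
    grads.append((y - p) * x)

n = len(grads)
avg_then_square = (sum(grads) / n) ** 2          # average first, then square
square_then_avg = sum(g * g for g in grads) / n  # square first, then average (diagonal Fisher)

print(f"average then square: {avg_then_square:.6f}")
print(f"square then average: {square_then_avg:.6f}")
```

With training labels instead of model samples, the second quantity becomes the empirical Fisher; near a minimum of the training loss the average gradient over the training set is also close to 0, so averaging first runs into the same problem.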

https://github.com/GMvandeVen/continual-learning/blob/ff0e03cb913ac0dea4fc59058968b1e6784decfd/continual_learner.py#L110-L125

GMvandeVen commented 4 years ago

Thanks for your comment! You’re right that the order in which the squaring and averaging operations are applied matters, for exactly the reason you point out. However, the code in my repository does use the correct order, because the Fisher Information matrix is calculated with batches of size 1 (see here).
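The batch-of-size-1 approach can be sketched roughly as follows. This is an illustrative reconstruction, not the repository's code: the function name `diag_empirical_fisher`, the toy two-layer classifier and the random data are all hypothetical, and the model is assumed to output log-probabilities (as required by `F.nll_loss`).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def diag_empirical_fisher(model, dataset):
    """Hypothetical helper: diagonal empirical Fisher, one forward/backward pass per sample.

    Assumes `model` outputs log-probabilities and `dataset` yields (x, y) pairs
    where x has no batch dimension and y is a 0-dim class-index tensor.
    """
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    model.eval()
    count = 0
    for x, y in dataset:
        model.zero_grad()
        # A "batch" of size 1: the gradient below belongs to this sample alone.
        loss = F.nll_loss(model(x.unsqueeze(0)), y.view(1))
        loss.backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                # Square first; the sign of the NLL gradient disappears here.
                fisher[n] += p.grad.detach() ** 2
        count += 1
    # Then average over all samples.
    return {n: f / count for n, f in fisher.items()}


# Usage on a hypothetical toy classifier and random data.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 3), nn.LogSoftmax(dim=1))
xs, ys = torch.randn(20, 4), torch.randint(0, 3, (20,))
fisher = diag_empirical_fisher(model, zip(xs, ys))
```

By Jensen's inequality, every entry of this estimate is at least as large as the square of the corresponding entry of the batch-averaged gradient.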

I admit this might not be the most efficient implementation (alternative suggestions are very welcome!), but I chose it for now because, as far as I’m aware, PyTorch currently does not provide access to the gradients of the individual terms in a sum.
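As one possible alternative, newer PyTorch releases (2.0 and later) include `torch.func`, which can compute per-sample gradients in a single vectorised call by combining `grad` with `vmap`. The sketch below assumes that API is available; the helper name `diag_empirical_fisher_vmap`, the toy model and the data are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, grad, vmap


def diag_empirical_fisher_vmap(model, xs, ys):
    """Hypothetical helper: diagonal empirical Fisher from vectorised per-sample gradients."""
    params = {n: p.detach() for n, p in model.named_parameters()}

    def sample_loss(params, x, y):
        # Evaluate the model on one sample by adding a batch dimension of size 1.
        log_probs = functional_call(model, params, (x.unsqueeze(0),))
        return F.nll_loss(log_probs, y.unsqueeze(0))

    # grad() differentiates with respect to params; vmap() maps it over the samples.
    per_sample = vmap(grad(sample_loss), in_dims=(None, 0, 0))(params, xs, ys)
    # Each entry has shape (num_samples, *param.shape): square first, then average.
    return {n: (g ** 2).mean(dim=0) for n, g in per_sample.items()}


# Usage on a hypothetical toy classifier and random data.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 3), nn.LogSoftmax(dim=1))
xs, ys = torch.randn(8, 4), torch.randint(0, 3, (8,))
fisher = diag_empirical_fisher_vmap(model, xs, ys)
```

This trades the Python loop for higher peak memory, since all per-sample gradients of a batch are materialised at once.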

tangbinh commented 4 years ago

I see. Estimating the Fisher information matrix from a single example seems to result in high variance (see this notebook, for example). Have you seen any difference when using a larger batch size? As you said, I don’t think it’s possible to access individual gradients with a single backward call.

GMvandeVen commented 4 years ago

Sorry, I should have been clearer; I realise my use of ‘batches’ is a bit confusing here. What I meant is that the backward passes(*) are done one by one for each sample used to calculate the Fisher Information matrix. The number of samples used to calculate the Fisher Information matrix is typically not 1 (it is set by the option --fisher-n; by default the full training set is used).
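The distinction between one sample per backward pass and one sample per estimate can be illustrated with a small framework-free experiment. It reuses a hypothetical one-parameter logistic model (`w = 1.5`, labels drawn from the model) and repeats the estimate 200 times for several sample counts `n`, which here only plays the role that `--fisher-n` plays in the repository. The spread of the estimate shrinks roughly like 1/sqrt(n).

```python
import math
import random
import statistics

random.seed(0)
w = 1.5  # hypothetical fixed weight of a 1-D logistic model p(y=1|x) = sigmoid(w * x)


def sq_grad():
    """Squared gradient of log p(y|x) for one sample (x, y) drawn from the model."""
    x = random.gauss(0.0, 1.0)
    p = 1.0 / (1.0 + math.exp(-w * x))
    y = 1 if random.random() < p else 0
    return ((y - p) * x) ** 2


def fisher_estimate(n):
    """Diagonal Fisher estimate from n samples: average of squared per-sample gradients."""
    return sum(sq_grad() for _ in range(n)) / n


spread = {}
for n in (1, 10, 100, 1000):
    estimates = [fisher_estimate(n) for _ in range(200)]
    spread[n] = statistics.stdev(estimates)
    print(f"n={n:5d}  mean={statistics.mean(estimates):.4f}  std={spread[n]:.4f}")
```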

(*) Actually, the forward passes too; I now realise this could be made more efficient by at least performing the forward passes with larger ‘batches’. I’ll look into that when I get some time.
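A rough sketch of that idea: one forward pass over a batch with `reduction="none"` keeps a separate loss per sample, after which `torch.autograd.grad(..., retain_graph=True)` is called once per sample. The helper name `diag_empirical_fisher_batched`, the toy model and the data are hypothetical. This assumes the model contains no layers that couple samples within a batch (for example BatchNorm in training mode); otherwise the per-sample gradients would differ from the batch-of-size-1 ones.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def diag_empirical_fisher_batched(model, xs, ys):
    """Hypothetical helper: batched forward pass, one backward pass per sample."""
    params = [p for p in model.parameters() if p.requires_grad]
    fisher = [torch.zeros_like(p) for p in params]
    # Single forward pass; reduction="none" keeps one loss value per sample.
    losses = F.nll_loss(model(xs), ys, reduction="none")
    for i in range(len(losses)):
        # Gradient of sample i only; retain_graph keeps the graph for the next sample.
        grads = torch.autograd.grad(losses[i], params, retain_graph=True)
        for f, g in zip(fisher, grads):
            f += g ** 2  # square first
    # Then average over samples.
    return [f / len(losses) for f in fisher]


# Usage on a hypothetical toy classifier and random data.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 3), nn.LogSoftmax(dim=1))
xs, ys = torch.randn(8, 4), torch.randint(0, 3, (8,))
fisher = diag_empirical_fisher_batched(model, xs, ys)
```

The backward passes are still sequential, so this mainly saves forward-pass time; the `torch.func` route vectorises both.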