kuc2477 / pytorch-ewc

Unofficial PyTorch implementation of DeepMind's PNAS 2017 paper "Overcoming catastrophic forgetting in neural networks"
MIT License
265 stars, 45 forks

The size of fisher matrix is wrong #2

Closed: Klitter closed this issue 5 years ago.

Klitter commented 5 years ago

The code on line 80 of model.py:

    fisher_diagonals = [(g ** 2).mean() for g in loglikelihood_grads]

should be:

    fisher_diagonals = [(g ** 2).mean(dim=0) for g in loglikelihood_grads]
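To see why the missing dim argument matters, here is a minimal sketch that assumes each tensor in loglikelihood_grads stacks per-sample gradients along dimension 0. The helper estimate_fisher_diagonals, the toy nn.Linear model, and the use of the observed labels (an empirical Fisher estimate) are hypothetical choices for illustration, not the repository's code. Calling .mean() with no arguments averages over every element and returns a 0-dim scalar, whereas .mean(dim=0) averages over samples only and keeps one Fisher entry per parameter element.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def estimate_fisher_diagonals(model, xs, ys):
    """Hypothetical helper: diagonal empirical Fisher, buggy vs. fixed.

    Assumption: the Fisher is estimated from the observed labels ys,
    using one log-likelihood gradient per sample.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    log_probs = F.log_softmax(model(xs), dim=1)
    # Log-likelihood of each sample's label, shape (n,).
    loglikelihoods = log_probs[torch.arange(len(ys)), ys]

    per_sample = []
    for i, ll in enumerate(loglikelihoods):
        # Keep the graph alive until the last sample has been differentiated.
        grads = torch.autograd.grad(
            ll, params, retain_graph=i < len(loglikelihoods) - 1
        )
        per_sample.append(grads)

    # Regroup by parameter: each tensor has shape (n, *param.shape).
    loglikelihood_grads = [torch.stack(gs) for gs in zip(*per_sample)]

    # Buggy: .mean() averages over every element, giving a 0-dim scalar.
    buggy = [(g ** 2).mean() for g in loglikelihood_grads]
    # Fixed: average over the sample dimension only, keeping param.shape.
    fixed = [(g ** 2).mean(dim=0) for g in loglikelihood_grads]
    return buggy, fixed


torch.manual_seed(0)
model = nn.Linear(4, 3)
xs = torch.randn(8, 4)
ys = torch.randint(0, 3, (8,))
buggy, fixed = estimate_fisher_diagonals(model, xs, ys)
for p, b, f in zip(model.parameters(), buggy, fixed):
    print(tuple(p.shape), tuple(b.shape), tuple(f.shape))
# prints:
# (3, 4) () (3, 4)
# (3,) () (3,)
```

Because every per-sample gradient slice has the same size, the buggy scalar equals the mean of the fixed tensor, so the buggy version collapses each parameter's per-element importance into a single shared value.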

kuc2477 commented 5 years ago

You are right, and this is a critical bug. I fixed it and updated the change log in the README. I also checked that the experimental results are reproduced. Thanks for reporting.