GMvandeVen / continual-learning

PyTorch implementation of various methods for continual learning (XdG, EWC, SI, LwF, FROMP, DGR, BI-R, ER, A-GEM, iCaRL, Generative Classifier) in three different scenarios.
MIT License

average EWC pickle #6

Closed Johswald closed 5 years ago

Johswald commented 5 years ago

Hey again, it seems that incorrect accuracies (averages) get saved in the pickle. The first list of "all_tasks" should contain the accuracies measured after training on all tasks, correct? I did not check other methods, but for online EWC they seem wrong, whereas for SI they seem correct. Can you check that? Thanks!

GMvandeVen commented 5 years ago

Hey, I quickly checked but I couldn't find any mistakes and those accuracies saved into the pickle seem correct to me. I think the confusion is with how the "all_tasks" entry is organised, which isn't very intuitive. The first list of "all_tasks" actually contains the test accuracy for the first task measured after training on each of the tasks. The n-th entry of the m-th list is the test accuracy for task m after training on task n. Sorry for that weird structure. I hope this explains it, but if not please let me know and I'll have another look.
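For reference, here is a minimal sketch of how that structure could be indexed (assuming the pickle holds a dictionary with the "all_tasks" entry organised as just described; the file name is only a placeholder):

```python
import pickle

# Load the saved results (file name is a placeholder).
with open("results.pkl", "rb") as f:
    results = pickle.load(f)

all_tasks = results["all_tasks"]
# all_tasks[m-1][n-1] = test accuracy on task m, measured after training on task n.
m, n = 1, 3
print(all_tasks[m - 1][n - 1])
```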

Johswald commented 5 years ago

Hey, thank you. I will check that again. Many benchmarks report the accuracy of tasks 1-n after training on all tasks. To get that, I have to iterate over the lists in "all_tasks" and take the last entry of each? Thank you again!


GMvandeVen commented 5 years ago

Yes, that's correct!
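In code, that extraction would look roughly like this (again assuming the pickle contains a dictionary with the "all_tasks" entry as described above; the file name is a placeholder):

```python
import pickle

with open("results.pkl", "rb") as f:
    results = pickle.load(f)

# For each task, take the last entry, i.e. the accuracy measured
# after training on the final task.
final_accs = [task_accs[-1] for task_accs in results["all_tasks"]]
print(final_accs)
```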

Johswald commented 5 years ago

Ok, I ask because I was confused about the results. When I run online EWC for 100 tasks (./main.py --ewc --online --lambda=5000 --gamma=1 --experiment permMNIST --scenario task --tasks 100 --fc-units=1000 --lr=0.0001) and then extract the accuracies for all tasks after learning all tasks, as you described, the accuracies for more recent tasks are low. Assuming I extracted the accuracies correctly, this means new tasks are not able to reach high accuracies. It looks like this:

[0.9565, 0.9502, 0.9283, 0.9196, 0.8324, 0.898, 0.8769, …, 0.7797, 0.7696, 0.7805]

The last entry here is the accuracy of task n directly after learning it. That would mean the regularisation hinders the network from actually learning new tasks (?), while protecting it (really well) from forgetting the old tasks. Did you experience the same? Do you suggest other hyperparameters? Many thanks!


GMvandeVen commented 5 years ago

Yes, that looks about right to me. The regularisation of EWC indeed hinders the network from learning new tasks (because it penalizes changes to parameters deemed important for previous tasks), and this effect gets stronger the more tasks have been learned. You can make the network more plastic by reducing lambda (reducing gamma might also work), which will probably increase the accuracies for the later tasks but reduce those for the earlier tasks. What the optimal values of these EWC hyperparameters are depends heavily on both the type and the number of tasks. (The optimal values for different task protocols can differ by several orders of magnitude; see for example Appendix D here: https://arxiv.org/abs/1904.07734.)
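If it helps, here is a sketch of how a sweep over lambda could be run with the same command you used above (the lambda values and the use of subprocess are purely illustrative; the other flags are copied from your run):

```python
import subprocess

# Illustrative sweep over the EWC regularisation strength (lambda values are arbitrary examples).
for ewc_lambda in [50, 500, 5000]:
    subprocess.run(
        ["./main.py", "--ewc", "--online", f"--lambda={ewc_lambda}", "--gamma=1",
         "--experiment", "permMNIST", "--scenario", "task", "--tasks", "100",
         "--fc-units=1000", "--lr=0.0001"],
        check=True,
    )
```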

Johswald commented 5 years ago

Thank you again!