GMvandeVen / continual-learning

PyTorch implementation of various methods for continual learning (XdG, EWC, SI, LwF, FROMP, DGR, BI-R, ER, A-GEM, iCaRL, Generative Classifier) in three different scenarios.
MIT License

Why use the predictions of all active classes? #7

Closed: suanrong closed this issue 5 years ago

suanrong commented 5 years ago

Thank you for the great implementation. I have a small question.

In the Class-IL setting, you use the predictions of all active classes to calculate the loss.

Why?

GMvandeVen commented 5 years ago

Thanks for your question! Firstly, I should note that for the Class-IL scenario, only the output-units corresponding to classes of future tasks (i.e., classes that have not yet been seen) are excluded. That means that in the Class-IL scenario, all classes seen so far are always 'active' and included in the loss. I now realise that the comments within the code are not very clear about this, apologies for that.

It is still a good question why the output-units corresponding to classes of future tasks are excluded, especially as it is indeed true that in some cases always including all classes would lead to (somewhat) better performance. The main reason I decided against always including all classes is that doing so is only possible when it is known a priori how many classes there are. Only including the classes that have been seen so far is therefore more general. (Although in the code the entire network, including all output-units, is generated at the start, this is just an implementation detail; it is also possible to add new output-units 'on-the-fly' when a new class is encountered.) But if, for the problem you are interested in, the total number of classes is known beforehand, you could indeed consider training with the output-units of all classes always included.
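(For readers of this thread: a minimal sketch of the masking being described, not the repository's exact code. Names such as `class_il_loss`, `current_task`, and `classes_per_task` are illustrative assumptions; the point is only that the cross-entropy is computed over the output-units of all classes seen so far, with the units of future classes excluded.)

```python
import torch.nn.functional as F

def class_il_loss(logits, targets, current_task, classes_per_task):
    """Cross-entropy over all classes seen up to and including the current task.

    logits:  [batch_size, total_num_classes] raw network outputs
    targets: [batch_size] global class labels
    """
    n_seen = current_task * classes_per_task   # number of classes seen so far
    active_logits = logits[:, :n_seen]         # exclude output-units of future classes
    return F.cross_entropy(active_logits, targets)

# e.g. split MNIST (2 classes per task), training on task 3: units 0-5 are 'active'
# loss = class_il_loss(logits, targets, current_task=3, classes_per_task=2)
```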

suanrong commented 5 years ago

Thanks for your reply. However, I think the right way is to calculate the loss based only on the current task's classes, instead of on all the active classes.

In your code, all the current samples are negatives for the classes of previous tasks; this kind of training severely harms the model and makes the accuracy drop to zero very fast.

If you train only on the current classes, the accuracy of the baseline becomes 15%~50%, with large variance. (splitMNIST, Class-IL)

I think the regularization-based methods would not perform that badly under this kind of training.
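(Again an illustrative, hedged sketch rather than code from the repository: the alternative described above restricts the loss to the output-units of the current task only, which also requires mapping the global class labels to that task's local label range. The function and argument names are assumptions.)

```python
import torch.nn.functional as F

def current_task_only_loss(logits, targets, current_task, classes_per_task):
    """Cross-entropy restricted to the output-units of the current task's classes.

    logits:  [batch_size, total_num_classes] raw network outputs
    targets: [batch_size] global class labels (all belonging to the current task)
    """
    first = (current_task - 1) * classes_per_task   # first class of the current task
    last = current_task * classes_per_task          # one past its last class
    task_logits = logits[:, first:last]             # only this task's output-units
    local_targets = targets - first                 # map global labels to the local range
    return F.cross_entropy(task_logits, local_targets)
```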

suanrong commented 5 years ago

Any response?

GMvandeVen commented 5 years ago

Sorry for the late reply; I hadn’t noticed that your comment had been edited. Yes, it seems I had misunderstood your initial question.

First, let me explain why, for the class-incremental learning scenario, we decided to always set the output-units of all classes seen so far to ‘active’. For each of the three scenarios, we always train a model on what it will be tested on later. In the class-incremental learning scenario, the model will need to learn to choose between all classes seen so far, and so that is what we train it on. Although I don’t think there is necessarily a ‘right’ way here, this seems to me the most logical approach.

That said, you are right that in practice, in certain circumstances, it is possible to somewhat boost the performance for the class-incremental learning scenario by always setting only the output-units of the classes in the current task to ‘active’. It is indeed interesting that this trick sometimes works, and I have recently played around with it as well. Essentially, I believe this trick depends on a training protocol with a (very) precise balance between the different tasks. In my opinion, this makes it questionable to what extent the difference between classes/tasks is really being ‘learned’ with this setup, for example because it is not robust against small variations in the training protocol (as exemplified by the large variance you mention).