zcapjdb / UKAEAGroupProject


Continual learning and active learning #39

lorenzozanisi opened this issue 2 years ago

lorenzozanisi commented 2 years ago

@thandi1908 @zcapjdb please read before tomorrow

Suppose we train on dataset S1 and then enrich it with more data S2, where S1 and S2 come from the same distribution. The training performance will of course change, and the retrained model will no longer be tailored to the S1 data alone. While on average there might be some drift in the model's performance on S1, this is expected - NNs overfit easily! So it's actually a good thing that the performance on S1 deteriorates.

The point is, you can't assess model performance based on the training set! The holdout set is the one that should always be looked at, if S1 and S2 are from the same distribution. This is probably why the shrink-perturb method doesn't matter too much for us.
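To make the point concrete, here is a minimal sketch (scikit-learn, synthetic data; sizes and hyperparameters are purely illustrative) of evaluating on a fixed holdout while the training set is enriched:

```python
# Sketch: carve out the holdout ONCE, then compare holdout MSE before/after enriching S1 with S2.
import numpy as np
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPRegressor
from sklearn.metrics import mean_squared_error

X, y = make_regression(n_samples=6000, n_features=10, noise=5.0, random_state=0)
X_train, X_holdout, y_train, y_holdout = train_test_split(X, y, test_size=0.2, random_state=0)
# S1: initial training set; S2: extra data from the same distribution.
X_s1, X_s2, y_s1, y_s2 = train_test_split(X_train, y_train, test_size=0.5, random_state=1)

model = MLPRegressor(hidden_layer_sizes=(64, 64), max_iter=300, random_state=0)
model.fit(X_s1, y_s1)
print("trained on S1,      holdout MSE:", mean_squared_error(y_holdout, model.predict(X_holdout)))

model.fit(np.vstack([X_s1, X_s2]), np.concatenate([y_s1, y_s2]))
print("trained on S1 + S2, holdout MSE:", mean_squared_error(y_holdout, model.predict(X_holdout)))
# The MSE on the original S1 training points may well get worse after retraining -
# that is fine; the fixed holdout number above is the one to judge the model by.
print("S1 train MSE after retraining:", mean_squared_error(y_s1, model.predict(X_s1)))
```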

BUT, if S1 and S2 are not from the same distribution, we would like our model to remember task 1 while training on task 2 - this is continual learning. This is the case, for example, if you want to predict turbulence quantities from data for different tokamaks that share the same input variables. It entails two validation sets, one for task 1 and one for task 2: the model is trained on task 1, and then the task 1 validation loss is monitored while the model is trained on task 2 (see the sketch below). Sure, you could have two models, one trained on S1 and one fine-tuned from it on S2, but having just one would be better. Also, with AL you can pick only the most relevant data from S2, which means you'd need to run QuaLiKiz only a few times. This is what I'd like to do for the broader paper.
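A minimal sketch of that bookkeeping (PyTorch, with random stand-in tensors and illustrative shapes; in practice the model would already have been trained on S1 before this loop):

```python
# Train on task 2 while monitoring both validation losses, so forgetting of task 1 is visible.
import torch
import torch.nn as nn

torch.manual_seed(0)
in_dim, out_dim = 15, 4  # illustrative only

model = nn.Sequential(nn.Linear(in_dim, 128), nn.ReLU(), nn.Linear(128, out_dim))
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

# Stand-ins for the S2 training set and the two fixed validation sets.
x2_train, y2_train = torch.randn(1024, in_dim), torch.randn(1024, out_dim)
x1_val, y1_val = torch.randn(256, in_dim), torch.randn(256, out_dim)
x2_val, y2_val = torch.randn(256, in_dim), torch.randn(256, out_dim)

for epoch in range(50):
    model.train()
    opt.zero_grad()
    loss = loss_fn(model(x2_train), y2_train)
    loss.backward()
    opt.step()

    model.eval()
    with torch.no_grad():
        val1 = loss_fn(model(x1_val), y1_val).item()  # is task 1 being forgotten?
        val2 = loss_fn(model(x2_val), y2_val).item()  # is task 2 being learned?
    print(f"epoch {epoch:02d}  task1_val={val1:.4f}  task2_val={val2:.4f}")
```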

So we can start with just the AL pipeline on the full dataset, no shrink-perturb or similar (but we need a better acquisition function than MSE alone). Then we can move to using AL for continual learning from one task to another - train on S1 with AL, then train only on S2, still with AL, but using shrink-perturb or other techniques like elastic weight consolidation (https://aiperspectives.springeropen.com/articles/10.1186/s42467-021-00009-8). This paper is a good starting point, see pages 7 and 8 (they also detect a change in the probability distribution of the data, but we won't need that as we have labels for it).
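Very rough sketches of both tricks, assuming a PyTorch model already trained on task 1; the hyperparameters and the all-ones Fisher placeholder are illustrative only, not a definitive implementation:

```python
import torch
import torch.nn as nn

def shrink_perturb(model, shrink=0.5, noise_std=0.01):
    """Shrink all weights towards zero and add small Gaussian noise
    (the warm-start trick) before continuing training on new data."""
    with torch.no_grad():
        for p in model.parameters():
            p.mul_(shrink).add_(noise_std * torch.randn_like(p))

def ewc_penalty(model, old_params, fisher):
    """Elastic weight consolidation: quadratic penalty keeping parameters close to
    their task-1 values, weighted by a diagonal Fisher information estimate."""
    penalty = 0.0
    for name, p in model.named_parameters():
        penalty = penalty + (fisher[name] * (p - old_params[name]) ** 2).sum()
    return penalty

# Usage sketch: after training on S1, snapshot the parameters and a Fisher estimate.
model = nn.Sequential(nn.Linear(15, 128), nn.ReLU(), nn.Linear(128, 4))
old_params = {n: p.detach().clone() for n, p in model.named_parameters()}
# In practice the Fisher diagonal is estimated from squared gradients on task-1 data;
# an all-ones placeholder is used here only to keep the sketch self-contained.
fisher = {n: torch.ones_like(p) for n, p in model.named_parameters()}

shrink_perturb(model)  # applied once before training on S2
# The task-2 training loop would then minimise:
#   task2_loss + lam * ewc_penalty(model, old_params, fisher)   (lam is a tunable weight)
```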

lorenzozanisi commented 2 years ago

See the train loss for the AL pipeline (which is disgusting) vs the validation loss (which keeps going down) - each vertical line is a new training. This is basically what we are seeing as well for the training data, but not for the validation, I believe (potentially because your validation set changed every iteration).

[figure: train loss vs validation loss across AL retrainings, vertical lines marking each new training]

Source: https://ui.adsabs.harvard.edu/abs/2020MNRAS.491.1554W/abstract
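For our pipeline the fix is cheap: carve the validation set out once, before the AL loop, and keep it fixed across retrainings. The toy sketch below (scikit-learn, synthetic data, hypothetical sizes) also uses an ensemble-variance acquisition as one example of something richer than MSE alone:

```python
# Fixed validation set across AL iterations + ensemble-disagreement acquisition (illustrative).
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPRegressor
from sklearn.metrics import mean_squared_error

rng = np.random.default_rng(0)
X = rng.uniform(-1, 1, size=(5000, 10))
y = np.sin(X).sum(axis=1) + 0.05 * rng.normal(size=5000)  # toy stand-in for QuaLiKiz outputs

# The validation set is split off ONCE and never changes, so the curve is comparable across iterations.
X_pool, X_val, y_pool, y_val = train_test_split(X, y, test_size=0.2, random_state=0)
labelled = np.zeros(len(X_pool), dtype=bool)
labelled[rng.choice(len(X_pool), 200, replace=False)] = True

for it in range(5):
    # Small ensemble; the spread of its predictions is the acquisition score.
    ensemble = [MLPRegressor(hidden_layer_sizes=(64,), max_iter=300, random_state=s)
                .fit(X_pool[labelled], y_pool[labelled]) for s in range(3)]
    preds = np.stack([m.predict(X_pool[~labelled]) for m in ensemble])
    val_mse = np.mean([mean_squared_error(y_val, m.predict(X_val)) for m in ensemble])
    print(f"iteration {it}: labelled={labelled.sum()}  fixed-val MSE={val_mse:.4f}")

    # Acquire the pool points where the ensemble disagrees most (highest prediction variance).
    candidate_idx = np.flatnonzero(~labelled)
    acquire = candidate_idx[np.argsort(preds.var(axis=0))[-100:]]
    labelled[acquire] = True
```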