GMvandeVen / continual-learning

PyTorch implementation of various methods for continual learning (XdG, EWC, SI, LwF, FROMP, DGR, BI-R, ER, A-GEM, iCaRL, Generative Classifier) in three different scenarios.

Knowledge Distillation Loss #5

Closed: Johswald closed this issue 5 years ago

Johswald commented 5 years ago

Hey, in order to compute the cross-entropy between the "soft" targets and the predictions, you do the following:

KD_loss_unnorm = (-targets_norm * log_scores_norm).mean() #--> average over batch

Wouldn't the correct cross-entropy, with the mean taken over the batch, be:

KD_loss_unnorm = (-targets_norm * log_scores_norm).sum(dim=1).mean()
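For context, here is a minimal self-contained sketch of the corrected reduction. The function name, the temperature argument `T`, and the `T**2` scaling are illustrative assumptions for a standard distillation loss, not necessarily the repository's exact code:

```python
import torch
import torch.nn.functional as F

def kd_loss(scores, target_scores, T=2.0):
    """Sketch of a knowledge-distillation loss: cross-entropy between
    temperature-softened teacher targets and the student's predictions.
    Names and defaults here are illustrative, not the repo's exact API."""
    log_scores_norm = F.log_softmax(scores / T, dim=1)   # student log-probabilities
    targets_norm = F.softmax(target_scores / T, dim=1)   # softened teacher probabilities
    # Corrected reduction: sum over the class dimension first,
    # then average over the batch dimension.
    kd_loss_unnorm = (-targets_norm * log_scores_norm).sum(dim=1).mean()
    # Commonly scaled by T**2 so gradient magnitudes stay comparable across temperatures.
    return kd_loss_unnorm * T ** 2

# Example usage with random logits (batch of 8, 10 classes):
student_logits = torch.randn(8, 10)
teacher_logits = torch.randn(8, 10)
loss = kd_loss(student_logits, teacher_logits, T=2.0)
```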

GMvandeVen commented 5 years ago

Yes, you are right. Thank you for pointing this out. There should indeed be a summation over the classes before taking the mean over the batch. This has now been corrected in the code. Note that there was a similar issue with the binary classification loss (e.g., used in iCaRL), which has now also been corrected. My first impression is that the correction has only a small effect on the reported MNIST results, but I will test further. Especially for very long task protocols this correction might be quite important. Many thanks for your feedback!
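The analogous fix for the per-class binary loss mentioned above might look like the sketch below. This is only an illustration of the same reduction change (sum over classes, then mean over the batch); the function and argument names are assumptions rather than the repository's actual code:

```python
import torch
import torch.nn.functional as F

def binary_distillation_loss(scores, target_probs):
    """Illustrative per-class binary cross-entropy (iCaRL-style), with the
    corrected reduction. Names here are hypothetical."""
    # Element-wise binary cross-entropy between logits and (possibly soft) targets.
    per_element = F.binary_cross_entropy_with_logits(
        scores, target_probs, reduction='none')
    # Sum over classes first, then average over the batch,
    # rather than taking a flat mean over all elements.
    return per_element.sum(dim=1).mean()
```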

Johswald commented 5 years ago

Thank you for the nice code!