imirzadeh / Teacher-Assistant-Knowledge-Distillation

Using Teacher Assistants to Improve Knowledge Distillation: https://arxiv.org/pdf/1902.03393.pdf
MIT License

Cifar 10 training - resnet 110 to resnet 8 #10

Closed karanchahal closed 4 years ago

karanchahal commented 5 years ago

Hello, I just trained a resnet110 and distilled it to a resnet8 using baseline knowledge distillation. The teacher model reached an accuracy of ~80%, while the distilled version reached 73%.

Upon using TAKD, the resnet110 distills a resnet20, which in turn distills a resnet8. The accuracy of the resnet20 was 79% after distillation, but the accuracy of the resnet8 distilled via the resnet20 was still 73%.

This suggests there is no improvement from using TAKD. Am I doing something wrong?

If you have thoughts about this experiment, I would love to know.

Thanks for the paper and the code !

karanchahal commented 5 years ago

I used the following config: `{'T_student': 1, 'lambda_student': 0.05}`

Are there better hyperparameters than these for cifar10?

imirzadeh commented 5 years ago

Hi,

First, are you sure you are training correctly? On cifar10, without any knowledge distillation, Resnet-110 should get around 93% accuracy, Resnet-20 around 91.5%, and Resnet-8 around 88.5%. Maybe you are training for only a few epochs. You can also check the numbers reported here for cifar10.

With knowledge distillation from Resnet-110, Resnet-20's accuracy becomes 92.82% (~1.3% improvement) with these hyperparameters: `{ 'lambda_student': 0.4, 'T_student': 10, 'seed': 31 }`. The weight file for Resnet-20 is also available here.

Now, with this Resnet-20 as our TA, if we train the student, the accuracy will be 88.9% (compared to 88.65% for normal knowledge distillation): `{ 'lambda_student': 0.95, 'T_student': 5, 'seed': 31 }`. The weight file for Resnet-8 is also available here.

However, since my code may not be fully deterministic, you might also try these parameters with other seeds.
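For anyone following along, here is a minimal sketch of how `lambda_student` and `T_student` typically enter the distillation objective (the standard Hinton-style formulation: a weighted sum of hard-label cross-entropy and a temperature-softened KL term). The exact weighting in this repo's code may differ slightly, so treat this as an illustration rather than the repo's implementation:

```python
import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, labels, T, lambda_student):
    """Hinton-style KD objective: blend hard-label cross-entropy with the
    softened teacher/student KL term, weighted by lambda_student."""
    ce = F.cross_entropy(student_logits, labels)
    kl = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)  # T^2 keeps gradient magnitudes comparable across temperatures
    return (1.0 - lambda_student) * ce + lambda_student * kl

# Toy usage on random logits (batch of 4, 10 classes)
torch.manual_seed(0)
student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
loss = kd_loss(student_logits, teacher_logits, labels, T=10, lambda_student=0.4)
```

Both terms are non-negative, so the blended loss is as well; during training it replaces the plain cross-entropy in the student's optimization step.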

Hope this helps.

imirzadeh commented 5 years ago

If you are interested, you can check these numbers I got for cifar10.

| model | accuracy |
| --- | --- |
| resnet 110 | 93.87 |
| resnet 56 | 93.27 |
| resnet 32 | 92.64 |
| resnet 26 | 92.48 |
| resnet 20 | 91.67 |
| resnet 14 | 91.05 |
| resnet 8 | 88.52 |
| 110 -> 56 | 94.02 |
| 110 -> 32 | 93.52 |
| 110 -> 20 | 92.82 |
| 110 -> 14 | 92.16 |
| 110 -> 8 | 88.65 |
| 110 -> 56 -> 8 | 88.7 |
| 110 -> 32 -> 8 | 88.73 |
| 110 -> 20 -> 8 | 88.9 |
| 110 -> 14 -> 8 | 88.98 |
karanchahal commented 4 years ago

Hello, thank you for this information; it was really enlightening. I have a question regarding the loss function for KD:

```python
loss_KD = nn.KLDivLoss()(F.log_softmax(output / T, dim=1),
                         F.softmax(teacher_outputs / T, dim=1))
```

Why aren't we using log softmax here for the teacher outputs? Is this by design? I can't understand the reason behind it.

Also, is there a way to find the optimal values of lambda and temperature without going through the hyperparameter tuning route?

imirzadeh commented 4 years ago

I really haven't thought about using log_softmax, mainly because previous works all used softmax. But I think using log_softmax would make the logits more similar to each other and might restrict the knowledge transfer. If you want to control the logits, you can use T (temperature). But this is just a thought.
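For what it's worth, there is also a mechanical reason the teacher side uses plain softmax here: PyTorch's `nn.KLDivLoss` / `F.kl_div` expects its *input* in log-space and, by default, its *target* as plain probabilities, so passing `log_softmax` for the teacher would compute a different quantity. A small sketch checking the built-in call against a manual KL computation:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T = 5.0
student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)

# F.kl_div expects the input in log-space and (by default) the target as
# plain probabilities -- hence log_softmax for the student, softmax for
# the teacher.
p = F.softmax(teacher_logits / T, dim=1)           # teacher distribution
log_q = F.log_softmax(student_logits / T, dim=1)   # student log-distribution
kl_builtin = F.kl_div(log_q, p, reduction="batchmean")

# Manual KL(p || q) = sum_i p_i * (log p_i - log q_i), averaged over the batch
kl_manual = (p * (p.log() - log_q)).sum(dim=1).mean()
```

The two values agree, confirming that the softmax/log_softmax asymmetry in the quoted loss is just the API's expected input format.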

About the hyperparameters (T and lambda): I tried to find a relationship, but I couldn't find anything notable except that the closer the teacher and student are, the higher lambda usually (though not always) becomes.

karanchahal commented 4 years ago

Alright, thanks for the input!