haitongli / knowledge-distillation-pytorch

A PyTorch implementation for exploring deep and shallow knowledge distillation (KD) experiments with flexibility
MIT License
1.84k stars · 342 forks

Does the student net really learn what the teacher outputs? #9

Open zym1119 opened 6 years ago

zym1119 commented 6 years ago

I printed the first 32 labels from the teacher net's train dataloader and got: 14, 8, 29, 67, 59, 49, 73, 25, 4, 76, 11, 25, 82, 6, 11, 47, 28, 43, 40, 49, 27, 92, 62, 37, 64, 22, 38, 90, 14, 16, 27, 92, while the first 32 labels from the student net's train dataloader are: 86, 40, 14, 73, 50, 43, 40, 27, 1, 51, 11, 47, 32, 76, 28, 83, 32, 4, 52, 77, 3, 64, 24, 36, 80, 93, 96, 72, 26, 75, 47, 79.

So it seems that the teacher net's recorded outputs and the student net's batches are not indexed over the same samples at each batch.

haitongli commented 6 years ago

It makes sense that they're different. The idea of knowledge distillation goes beyond simply adding extra supervision such as final output labels. My understanding is that the student is not mimicking the teacher's output labels; it learns on its own from the training data and is regularized by the "dark knowledge" from the teacher. If you look at the KD loss function, it is a weighted combination of hard labels (training samples) and soft labels (teacher outputs). Hinton's paper has an insightful discussion around this.
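
For reference, the KD loss is essentially of the following form, a weighted sum of a softened KL term and an ordinary cross-entropy term (a minimal sketch; the function name and signature here are illustrative rather than the repo's exact code):

import torch.nn as nn
import torch.nn.functional as F

def kd_loss(student_logits, labels, teacher_logits, alpha, T):
    # soft part: KL divergence between the softened student and teacher distributions
    soft = nn.KLDivLoss()(F.log_softmax(student_logits / T, dim=1),
                          F.softmax(teacher_logits / T, dim=1)) * alpha * T * T
    # hard part: ordinary cross entropy against the ground-truth labels
    hard = F.cross_entropy(student_logits, labels) * (1. - alpha)
    return soft + hard

The T*T factor follows Hinton et al.: the gradients from the soft targets scale as 1/T^2, so multiplying by T^2 keeps the soft and hard contributions on a comparable scale as the temperature changes.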

zym1119 commented 6 years ago

If alpha=0.95 and T=6, then alpha*T*T (0.95 * 36 = 34.2) is far larger than 1-alpha (0.05), so I don't see how the student is not mimicking the teacher's output; the loss depends almost entirely on the KL divergence between the student and the teacher. Have you tried training a student on ImageNet?

haitongli commented 6 years ago

It might be interesting to do a hyperparameter exploration (using search_hyperparams.py) over alpha and T (I did that for my course project, though not intensively, due to time/resource limitations). If we use a near-one alpha (forcing the student to mimic the teacher more) with an improper temperature, we sometimes see a sharp drop in accuracy. Sometimes, though, even with a small alpha, a lot of knowledge can be distilled from the teacher into the student (measured by accuracy improvement). My point was simply that the loss contribution from the distillation part can be more complicated than the intuitive sense of "mimicking" the teacher's output labels. It also relates to the bias-variance tradeoff from a traditional ML perspective.

No, I haven't touched ImageNet due to resource constraints...
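
A rough sketch of that kind of sweep could look like this (the grid values and the train_and_evaluate helper are placeholders, not the actual search_hyperparams.py interface):

import itertools

# candidate values for the KD hyperparameters; the grid itself is an assumption
alphas = [0.1, 0.5, 0.9, 0.95, 0.99]
temperatures = [1, 2, 4, 6, 10, 20]

results = {}
for alpha, T in itertools.product(alphas, temperatures):
    # train_and_evaluate is a hypothetical helper that runs one full student
    # training job and returns its test accuracy
    acc = train_and_evaluate(alpha=alpha, temperature=T)
    results[(alpha, T)] = acc

best = max(results, key=results.get)
print("best (alpha, T):", best, "accuracy:", results[best])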

zym1119 commented 6 years ago

Thanks for answering. I need to take a detailed look at your report and run some experiments.

zym1119 commented 6 years ago

I split the loss into two parts: one is the cross entropy between the student outputs and the teacher outputs with temperature T (the "soft" loss), and the other is the cross entropy between the student outputs and the labels (the "hard" loss). In your code, because the teacher outputs and the student's batches are misaligned, the "soft" loss stays essentially constant at 0.045 throughout training, while the "hard" loss is optimized from 0.198 down to a relatively low value. This means the network is always learning from the targets rather than from the teacher; the "soft" loss is just noise for the student. I guess this acts as a kind of regularization, which is why the student network generalizes slightly better than before.
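
A quick way to see the misalignment directly is to check how often the recorded teacher outputs agree with the labels in the student's batches (a rough sketch; teacher_outputs_recorded is a hypothetical stand-in for however the per-batch teacher outputs are stored):

import torch

agree, total = 0, 0
for i, (images, labels) in enumerate(train_dataloader):
    teacher_batch = teacher_outputs_recorded[i]        # looked up by batch index
    teacher_pred = torch.argmax(teacher_batch, dim=1)  # teacher's predicted class
    agree += (teacher_pred == labels).sum().item()
    total += labels.size(0)

# with aligned batches this should be close to the teacher's training accuracy;
# with reshuffled batches it drops to roughly chance level
print("teacher/label agreement:", agree / total)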

DavidBegert commented 6 years ago

Am I wrong that the teacher outputs are supposed to be aligned with the student outputs?

michaelklachko commented 6 years ago

Here's how I did it:

import torch.nn as nn
import torch.nn.functional as F

for i in range(num_train_batches):
    input = train_inputs[i*batch_size:(i+1)*batch_size]
    label = train_labels[i*batch_size:(i+1)*batch_size]
    output = model(input)
    teacher_output = teacher_model(input).detach()  # teacher runs on the SAME batch; no gradients needed
    # KD loss: soft KL term (scaled by T*T) plus hard cross-entropy term
    loss = nn.KLDivLoss()(F.log_softmax(output / T, dim=1),
                          F.softmax(teacher_output / T, dim=1)) * alpha * T * T \
           + F.cross_entropy(output, label) * (1 - alpha)

I see a consistent improvement of ~1% on CIFAR-10 with T=6 and alpha=0.9.

chenxshuo commented 5 years ago

Am I wrong that the teacher outputs are supposed to be aligned with the student outputs?

I think these two outputs should be aligned. It seems that @peterliht recorded the teacher outputs at the beginning and reuses them during student training, indexing them by batch through enumerate(dataloader). But the dataloader shuffles every epoch, so the batches are not the same for the teacher and the student.
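
If one still wants to precompute the teacher outputs (to avoid running the teacher every step), a way to keep them aligned is to record them with shuffle=False and then look them up by sample index rather than by batch position. A minimal sketch, assuming the training dataset is wrapped so that it also returns each sample's index, and reusing the kd_loss sketch from earlier in the thread (names here are illustrative):

import torch
from torch.utils.data import DataLoader

# 1) record teacher outputs once, with shuffle=False so position i maps to sample i
record_loader = DataLoader(train_set, batch_size=batch_size, shuffle=False)
teacher_logits = []
with torch.no_grad():
    for images, _ in record_loader:
        teacher_logits.append(teacher_model(images))
teacher_logits = torch.cat(teacher_logits)  # shape: (num_samples, num_classes)

# 2) during student training, fetch the recorded outputs by each sample's dataset
#    index (the loader is assumed to yield that index alongside image and label)
for images, labels, idx in train_loader:
    output = model(images)
    loss = kd_loss(output, labels, teacher_logits[idx], alpha, T)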