I use an unweighted sum.
I would suggest checking whether you have correctly implemented the distillation loss, since I don't see why its gradients would be much smaller.
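For context, here is a minimal sketch of what an unweighted sum can look like in one training step, following the paper's description of in-place distillation. The width list and the `set_active_width` method are my own placeholders, not the actual API of this repo:

```python
import torch
import torch.nn.functional as F

def train_step(model, optimizer, images, targets, widths=(0.25, 0.5, 0.75, 1.0)):
    optimizer.zero_grad()

    # Full-width network is trained with the ground-truth (hard) labels.
    model.set_active_width(max(widths))          # placeholder for switching widths
    max_output = model(images)
    F.cross_entropy(max_output, targets).backward()

    # Soft labels for in-place distillation; detach so no gradient flows to the teacher.
    soft_target = F.softmax(max_output, dim=1).detach()

    # Sub-networks are trained only with the distillation loss.
    for width in widths:
        if width == max(widths):
            continue
        model.set_active_width(width)
        output = model(images)
        kd_loss = torch.sum(-soft_target * F.log_softmax(output, dim=1), dim=1).mean()
        kd_loss.backward()                       # gradients accumulate: an unweighted sum

    optimizer.step()
```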
Here is how I implement the KD loss:
loss = torch.nn.KLDivLoss()(F.log_softmax(output, dim=1), F.softmax(max_output, dim=1))
output is the sub-network output and max_output is the full-network output. I use log_softmax for output because the PyTorch docs say KLDivLoss expects log-probabilities as input.
I also think it is reasonable that the KD loss is smaller: at the early stage, the output of the full network is close to random (since it is not yet well trained), so the distance between the sub-network and full-network outputs is small.
Hi @JiahuiYu, is this the way you implement the distillation loss? Or could you please tell me how you implemented it?
Hi @VectorYoung
I did not use the torch.nn.KLDivLoss API; instead I used matrix multiplication to do the sum. I had a look at torch.nn.KLDivLoss, which computes y * (log y - x), where x is the log-probability input and y is the target probability. In this equation, the y * log y term has no effect on the gradients. Maybe this is the reason why you think the KD loss is smaller?
On my side, training is quite stable even at the first epoch. Can you report the accuracy of all sub-networks in the first five epochs?
Hi @JiahuiYu
I think I have figured it out. The API says the default reduction is 'mean', but it is an element-wise mean, not a batch mean, so it is not really computing the KL divergence (in my case the value is 1000 times smaller). You need to pass reduction='batchmean' to get the true KL divergence. The docs note that 'mean' will be changed to behave like 'batchmean' in the next major release. I will report my results after I revise it.
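A small example of the difference, with made-up tensor shapes (batch of 8, 1000 classes):

```python
import torch
import torch.nn.functional as F

output = torch.randn(8, 1000)      # sub-network logits
max_output = torch.randn(8, 1000)  # full-network logits

log_p = F.log_softmax(output, dim=1)
q = F.softmax(max_output, dim=1)

# Default 'mean' divides by batch_size * num_classes (element-wise mean),
# so the value is num_classes times smaller than the actual KL divergence.
loss_mean = F.kl_div(log_p, q, reduction='mean')

# 'batchmean' divides only by batch_size and matches the mathematical definition.
loss_batchmean = F.kl_div(log_p, q, reduction='batchmean')

print(loss_batchmean / loss_mean)  # ≈ 1000
```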
BTW, why do you use only the KD loss for the sub-networks? Is it better than using only the true label, or the true label + KD loss?
Thanks a lot for your help.
Great, sounds reasonable.
I have tried several variants of (true label + distillation), but the results are worse. This is mentioned briefly in the paper, but you are very welcome to explore it further.
Hi Jiahui, I am reproducing US-Net. How did you combine the hard-label loss (the loss at the max width) with the in-place distillation loss? I just sum them directly, but I notice that the in-place distillation loss is much smaller than the hard-label loss, so the gradients for the sub-networks are very small compared to those for the max-width network. Right now my max-width network is training well but the sub-networks improve little. The following is my training procedure:
So are you using a weighted sum of the losses? Or how did you combine them?