JiahuiYu / slimmable_networks

Slimmable Networks, AutoSlim, and Beyond, ICLR 2019, and ICCV 2019

How did you combine the hard label loss and the Distillation loss #12

Closed VectorYoung closed 5 years ago

VectorYoung commented 5 years ago

Hi Jiahui, I am reproducing the universally slimmable networks (US-Nets). How did you combine the hard-label loss (the loss at the max width) with the in-place distillation loss? I just sum them directly, but I notice that the in-place distillation loss is much smaller than the hard-label loss, so the gradients for the sub-networks are very small compared to those of the max-width network. Right now my max-width network trains well but the sub-networks improve very little. The following is my training procedure:

    # excerpt from my training loop (FLAGS, model, criterion, input, target are defined elsewhere)
    # first do max_width
    max_width = FLAGS.width_mult_range[1]
    model.apply(lambda m: setattr(m, 'width_mult', max_width))
    max_output = model(input)
    loss = torch.mean(criterion(max_output, target))
    loss.backward()
    max_output = max_output.detach()
    # then do other widths: min width plus two randomly sampled widths
    min_width = FLAGS.width_mult_range[0]
    width_mult_list = [min_width]
    sampled_width = list(np.random.uniform(FLAGS.width_mult_range[0],
                                           FLAGS.width_mult_range[1], 2))
    width_mult_list.extend(sampled_width)
    for width_mult in sorted(width_mult_list, reverse=True):
        model.apply(lambda m: setattr(m, 'width_mult', width_mult))
        output = model(input)
        loss = torch.nn.KLDivLoss()(F.log_softmax(output, dim=1),
                                    F.softmax(max_output, dim=1))
        loss.backward()

So are you using a weighted sum of the losses? Or how do you combine them?

JiahuiYu commented 5 years ago

I use an unweighted sum.

I would suggest checking whether you have implemented the distillation loss correctly, since I don't understand why the gradients would be much smaller.
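For concreteness, here is a minimal sketch of what an unweighted sum could look like in the training step above (an illustration only, reusing the placeholder names `model`, `criterion`, `optimizer`, `input`, `target`, and `FLAGS` from the snippet; not the author's exact code, and the distillation term is written here as a soft-target cross-entropy):

    import numpy as np
    import torch
    import torch.nn.functional as F

    def train_step(model, criterion, optimizer, input, target, FLAGS):
        optimizer.zero_grad()

        # Hard-label loss at the maximum width.
        max_width = FLAGS.width_mult_range[1]
        model.apply(lambda m: setattr(m, 'width_mult', max_width))
        max_output = model(input)
        total_loss = torch.mean(criterion(max_output, target))

        # Soft targets from the full network (detached: no gradient to the teacher).
        soft_target = F.softmax(max_output.detach(), dim=1)

        # In-place distillation for the min width plus two random widths,
        # each term added with weight 1.0 (unweighted sum).
        min_width, max_width = FLAGS.width_mult_range
        widths = [min_width] + list(np.random.uniform(min_width, max_width, 2))
        for width_mult in sorted(widths, reverse=True):
            model.apply(lambda m: setattr(m, 'width_mult', width_mult))
            output = model(input)
            # Cross-entropy against the soft targets (one common formulation).
            kd_loss = torch.mean(
                torch.sum(-soft_target * F.log_softmax(output, dim=1), dim=1))
            total_loss = total_loss + kd_loss

        total_loss.backward()
        optimizer.step()

Calling `backward()` once on the summed loss accumulates the same gradients as the per-width `backward()` calls in the snippet above; the per-width version just uses less memory because each sub-network's graph is freed immediately.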

VectorYoung commented 5 years ago

Here is how I implement the KD loss: `loss = torch.nn.KLDivLoss()(F.log_softmax(output, dim=1), F.softmax(max_output, dim=1))`, where `output` is the sub-network output and `max_output` is the full-network output. I use `log_softmax` for `output` because the PyTorch docs say `KLDivLoss` expects log-probabilities as input. And I think it is reasonable that the KD loss is smaller: at the early stage the output of the full network is quite random (since it is not yet well trained), so the distance between the sub-network and full-network outputs is small.

VectorYoung commented 5 years ago

Hi @JiahuiYu Is this the way you implement the distillation loss? Or could you please tell me how you implemented it?


JiahuiYu commented 5 years ago

Hi @VectorYoung

I did not use the torch.nn.KLDivLoss API; instead I use matrix multiplication to sum over the classes.

I had a look at torch.nn.KLDivLoss, which computes y * (log y - x). In this expression the y log y term has no effect on the gradients. Maybe this is the reason why you think the KD loss is smaller?

On my side the training is quite stable even in the first epoch. Can you report the accuracy of all sub-networks for the first five epochs?
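In code, "matrix multiplication to sum over" presumably corresponds to a soft-target cross-entropy along these lines (a sketch of that reading, not necessarily the repository's actual implementation; `output` and `max_output` as in the snippets above):

    import torch
    import torch.nn.functional as F

    def soft_cross_entropy(output, soft_target):
        # Per sample: -sum_c soft_target[c] * log_softmax(output)[c], then mean over the batch.
        return torch.mean(torch.sum(-soft_target * F.log_softmax(output, dim=1), dim=1))

    # usage with the variables from the snippets above:
    # kd_loss = soft_cross_entropy(output, F.softmax(max_output.detach(), dim=1))

This differs from the true KL divergence only by the term sum(y log y) of the detached soft targets, which is constant with respect to the sub-network and therefore contributes no gradient, as noted above.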

VectorYoung commented 5 years ago

Hi @JiahuiYu

I think I figured it out. The API docs say the default reduction is 'mean', but it is an element-wise mean, not a batch mean, so it is not really computing the KL divergence (in my case it is 1000 times smaller, one factor per class). You need to pass reduction='batchmean' to get the actual KL divergence. The docs also say that 'mean' will be changed to behave the same as 'batchmean' in the next major release. I will report my results after I revise it.
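A quick way to see the factor (a hypothetical check with random logits and 1000 classes, matching the factor mentioned above):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    num_classes = 1000
    student_logits = torch.randn(8, num_classes)
    teacher_logits = torch.randn(8, num_classes)

    log_p = F.log_softmax(student_logits, dim=1)
    q = F.softmax(teacher_logits, dim=1)

    # Default 'mean' averages over all elements (batch_size * num_classes) ...
    kl_mean = F.kl_div(log_p, q, reduction='mean')
    # ... while 'batchmean' divides the summed divergence by the batch size only,
    # which is the mathematically correct per-sample KL divergence.
    kl_batchmean = F.kl_div(log_p, q, reduction='batchmean')

    print(kl_mean.item(), kl_batchmean.item())  # batchmean is num_classes (1000x) larger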

BTW, why do you use only the KD loss for the sub-networks? Is it better than using only the true label, or the true label + KD loss?

Thanks a lot for your help.

JiahuiYu commented 5 years ago

Great, sounds reasonable.

I have tried several variants of (true label + distillation), but the results are worse. This is mentioned briefly in the paper, but you are very welcome to explore it further.
