ruizhoud / DistributionLoss

Source code for paper "Regularizing Activation Distribution for Training Binarized Deep Networks"

Implementation of [distrloss_layer.py] #2

Closed jack-kjlee closed 5 years ago

jack-kjlee commented 5 years ago

From [distrloss_layer.py]:

```python
distrloss1 = (torch.min(2 - mean - std, 2 + mean - std).clamp(min=0) ** 2).mean() \
             + ((std - 4).clamp(min=0) ** 2).mean()
distrloss2 = (mean ** 2 - std ** 2).clamp(min=0).mean()
```

According to the paper, distrloss1 seems to correspond to the saturation / gradient-mismatch loss.

Q1. Why 2 and 4 instead of 1, which is what the paper describes?

distrloss2 (Degeneration Loss): (mean ** 2 - std ** 2).clamp(min=0)

Q2. Why not (torch.abs(mean) - std).clamp(min=0) ** 2?
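
For context, here is a minimal sketch of where the `mean` and `std` in the snippet above could come from. This is my own illustration, not code from the repository: it assumes the statistics are taken per output channel of the pre-activations that feed the sign binarization.

```python
import torch

def preact_stats(preact):
    # `preact` is assumed to be an (N, C, H, W) pre-activation tensor feeding
    # a sign() binarization; statistics are computed per channel.
    flat = preact.transpose(0, 1).contiguous().view(preact.size(1), -1)
    mean = flat.mean(dim=1)
    std = flat.std(dim=1)
    return mean, std
```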

ruizhoud commented 5 years ago

Thanks for asking this question! We will update the comment in the code file about this.

Short answer: these numbers are hyper-parameters you can change and experiment with; different values should give similar but slightly different accuracy. The hyper-parameters in the code were tuned to give state-of-the-art performance on AlexNet.

Long answer: Fig. 4 in the paper shows the formulation of the distribution loss; it mainly conveys the intuition behind how the loss can be formed. For example, the gradient-mismatch loss is formulated as $(1 - \mu - k\sigma)_+^2$, which mainly says that the loss should be the positive part of the difference between a constant and $\mu + k\sigma$ (or $-\mu + k\sigma$ if the mean $\mu$ is negative). Using 2 instead of 1 is a more relaxed form of the loss, which gives slightly better accuracy. Similarly for the degeneration loss, the two variants (the one in the code and the one you wrote) function in the same way. For historical reasons, we used the formulation in the code for the AlexNet experiments.
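
For concreteness, here is a minimal sketch contrasting the relaxed constants used in the released code with the paper's Fig. 4 form, and the two degeneration-loss variants side by side. The function name `distr_losses` and the assumption that `mean`/`std` are per-channel pre-activation statistics are mine, not part of the repository; with `c=1, s=1` this reduces to the paper's formulation, with `c=2, s=4` to the released code.

```python
import torch

def distr_losses(mean, std, c=2.0, s=4.0):
    # Gradient-mismatch / saturation term: penalize |mean| + std exceeding c,
    # plus a term keeping std below s (the paper's Fig. 4 uses c = s = 1).
    loss1 = (torch.min(c - mean - std, c + mean - std).clamp(min=0) ** 2).mean() \
            + ((std - s).clamp(min=0) ** 2).mean()
    # Degeneration term, variant in the released code: mean^2 exceeding std^2.
    loss2_code = (mean ** 2 - std ** 2).clamp(min=0).mean()
    # Variant suggested in the question: |mean| exceeding std, squared hinge.
    loss2_alt = ((torch.abs(mean) - std).clamp(min=0) ** 2).mean()
    return loss1, loss2_code, loss2_alt
```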

Please let me know if you have any questions!

jack-kjlee commented 5 years ago

Thank you for your kind reply. I am trying to run BNN-DL on CIFAR-10 with a 7-layer network (6 convolutional layers, 1 fully-connected layer). However, validation accuracy does not follow training accuracy with the formula and hyper-parameters above. In my case, training accuracy reaches 80%, but validation accuracy stays at only 10~30% within 30 epochs (i.e., training loss decreases to the level I want, but validation loss increases). My implementation still works well if I set distribution loss = distribution loss * 0. Could you give me some advice? Or could you commit the code of BNN-DL for CIFAR-10?

With the code below:

```python
k1 = 0.25
k2 = 1
distrloss1 = (torch.min(1 - mean - (k1 * std), 1 + mean - (k1 * std)).clamp(min=0) ** 2).mean() \
             + (((k1 * std) - 1).clamp(min=0) ** 2).mean()
distrloss2 = (torch.abs(mean) - (k2 * std)).clamp(min=0).mean()
```

jack-kjlee commented 5 years ago

The cause was momentum=None in nn.BatchNorm(). Why momentum=None?

ruizhoud commented 5 years ago

In PyTorch, momentum=None means the running statistics are accumulated over the whole epoch (a cumulative moving average) to get an "accurate" batch mean and batch variance. If you check the learning rate schedule, the last epoch has a learning rate of 0 - that epoch only serves to collect the accurate batch mean and variance. It is normal to see very bad validation accuracy before that point, but after you run the epoch with lr = 0, the accuracy should be normal.

Btw, for debugging purposes, if you do not want to train all the epochs before seeing meaningful validation results, feel free to change the momentum to 0.1.
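
To illustrate the difference (a minimal sketch; the layer size and input shape are arbitrary): in torch.nn.BatchNorm2d, momentum=None switches the running statistics to a cumulative moving average over all batches seen, while momentum=0.1 (the PyTorch default) uses an exponential moving average.

```python
import torch
import torch.nn as nn

# momentum=None: running_mean/running_var are cumulative averages over every
# batch seen so far, so they only become reliable after a full pass over the
# data (e.g. the final lr = 0 epoch mentioned above).
bn_cumulative = nn.BatchNorm2d(64, momentum=None)

# momentum=0.1 (the default): exponential moving average, which tracks the
# statistics quickly and is convenient for debugging runs.
bn_ema = nn.BatchNorm2d(64, momentum=0.1)

x = torch.randn(8, 64, 32, 32)
bn_cumulative(x), bn_ema(x)  # both update their running statistics in training mode
```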

jack-kjlee commented 5 years ago

Thanks!