z-hXu / ReCU

Pytorch implementation of our paper accepted by ICCV 2021 -- ReCU: Reviving the Dead Weights in Binary Neural Networks http://arxiv.org/abs/2103.12369

How does torch.clamp() revive dead weights? #9

Open YKrisLiu opened 2 years ago

YKrisLiu commented 2 years ago

Appreciate your excellent work! I looked at the code in binarized_modules.py and found that torch.clamp() is applied to constrain the weights within ±Q_tau. In the backward pass, this stops the "dead weights" (those >= Q_tau or <= -Q_tau) from updating. However, from my understanding, "reviving dead weights" should push those weights toward 0 instead of stopping their gradient descent (as clamp does). Have I misunderstood your work?
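
For reference, a minimal sketch of the gradient behavior described above, using a fixed stand-in value for Q_tau (the names are illustrative and not taken from binarized_modules.py):

```python
import torch

# torch.clamp passes gradient only for elements inside [-q_tau, q_tau];
# elements outside the range receive zero gradient in the backward pass.
w = torch.tensor([-1.5, -0.3, 0.2, 1.2], requires_grad=True)
q_tau = 1.0  # illustrative stand-in for the quantile threshold Q_tau

w_clamped = torch.clamp(w, -q_tau, q_tau)
w_clamped.sum().backward()

print(w.grad)  # tensor([0., 1., 1., 0.]) -- weights beyond ±q_tau stop updating
```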

rainyBJ commented 2 years ago

> Appreciate your excellent work! I looked at the code in binarized_modules.py and found that torch.clamp() is applied to constrain the weights within ±Q_tau. In the backward pass, this stops the "dead weights" (those >= Q_tau or <= -Q_tau) from updating. However, from my understanding, "reviving dead weights" should push those weights toward 0 instead of stopping their gradient descent (as clamp does). Have I misunderstood your work?

From my perspective, it may be because Q_tau increases from a smaller value to a larger one during training, as mentioned in the paper (0.85 -> 0.99). So over the course of training, weights that were clamped under the earlier threshold of 0.85 are revived once the clipping threshold becomes larger.
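
A rough sketch of this point, assuming Q_tau is taken as the τ-quantile of the absolute weight values (the computation below is an illustrative stand-in, not the repo's actual code):

```python
import torch

# As tau grows from 0.85 to 0.99, the threshold Q_tau widens, so weights that
# were clipped under the earlier, tighter threshold fall back inside the
# allowed range later in training and can update again.
torch.manual_seed(0)
w = torch.randn(10000)  # stand-in for a layer's latent weights

for tau in (0.85, 0.99):
    q_tau = torch.quantile(w.abs(), tau).item()         # stand-in for the paper's Q_tau
    inside = (w.abs() <= q_tau).float().mean().item()   # fraction of weights left un-clipped
    print(f"tau={tau:.2f}  Q_tau={q_tau:.3f}  inside ±Q_tau: {inside:.2%}")
```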