mil-ad / snip

Pytorch implementation of the paper "SNIP: Single-shot Network Pruning based on Connection Sensitivity" by Lee et al.
MIT License

When creating weight masks, gradients are turned off for the weights instead of the masks #4

Closed Dawson557 closed 3 years ago

Dawson557 commented 3 years ago

Line 36 in snip.py should be layer.weight_mask.requires_grad = False, not layer.weight.requires_grad = False as it currently is, which essentially ruins the entire model :P

mil-ad commented 3 years ago

SNIP computes gradients with respect to multiplicative masks, not the weights. The idea is that the gradient on a mask estimates the effect of removing a weight entirely, rather than just perturbing it.
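A minimal sketch of that idea for a single linear layer (this is illustrative, not the repo's snip.py; `snip_scores` and its signature are made up here): the weights are detached so no gradient flows to them, while a mask of ones receives the gradient, giving the per-connection sensitivity scores from the paper.

```python
import torch
import torch.nn.functional as F

def snip_scores(weight, bias, x, y):
    """Connection sensitivity for one linear layer (illustrative sketch).

    Gradients flow to the multiplicative mask, not the frozen weights,
    which estimates the effect of removing each connection.
    """
    mask = torch.ones_like(weight, requires_grad=True)  # mask gets the grad
    w = weight.detach()                                 # weights are frozen
    logits = F.linear(x, w * mask, bias)
    loss = F.cross_entropy(logits, y)
    loss.backward()
    scores = mask.grad.abs()
    return scores / scores.sum()  # normalised sensitivities, sum to 1

torch.manual_seed(0)
x = torch.randn(8, 4)
y = torch.randint(0, 3, (8,))
w = torch.randn(3, 4)
scores = snip_scores(w, torch.zeros(3), x, y)
```

Keeping the top-k scores then yields the binary pruning mask; the weights themselves are never modified during this step.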

Dawson557 commented 3 years ago

Aren't you getting the same gradient when collecting it in terms of the weights, though? All I know is that since I made that switch, my models actually work when loading them back in, and they get higher accuracies with only 20% weight retention using SNIP. I didn't look through all your code, though, so maybe you turn the weight gradients back on elsewhere, but the end result is the same. https://github.com/antonioverdi/MLReproChallenge

mil-ad commented 3 years ago

The SNIP() function returns the pruning "mask", not the model. In fact it deep copies the model at the beginning of the function to make sure the model isn't changed at all. Applying the mask during actual training is another matter. See: https://github.com/mil-ad/snip/blob/883c8050d88554a8d85083013d7d1ddd7ead54c4/train.py#L31

In a nutshell, you'd have to set the pruned weights to zero and make sure they remain zero by either (1) setting them to zero before every forward pass or (2) setting their gradients to zero after each call to .backward()
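Both options can be sketched in a short training loop (a hedged sketch, not the repo's train.py; the layer, optimizer, and the random `mask` are placeholders). With plain SGD, either technique alone keeps the pruned entries at zero; both are shown in one loop for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
# Placeholder binary mask; in practice this comes from SNIP()
mask = (torch.rand_like(model.weight) > 0.5).float()

opt = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))

for _ in range(3):
    # Option (1): re-apply the mask before every forward pass
    with torch.no_grad():
        model.weight.mul_(mask)
    loss = nn.functional.cross_entropy(model(x), y)
    opt.zero_grad()
    loss.backward()
    # Option (2): zero the gradients of pruned weights after backward()
    model.weight.grad.mul_(mask)
    opt.step()
```

Note that with momentum or weight decay, option (1) is the safer choice, since the optimizer can move a weight away from zero even when its gradient is masked.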