Closed: Dawson557 closed this issue 3 years ago
SNIP computes gradients with respect to multiplicative masks, not the weights. The idea is that this estimates the effect of *removing* weights rather than perturbing them.
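A minimal sketch of that setup, assuming plain PyTorch (the `MaskedLinear` class and the layer sizes below are illustrative, not the repo's actual code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedLinear(nn.Linear):
    """Linear layer whose effective weight is weight * weight_mask."""
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        self.weight_mask = nn.Parameter(torch.ones_like(self.weight))
        self.weight.requires_grad = False  # gradients flow to the mask only

    def forward(self, x):
        return F.linear(x, self.weight * self.weight_mask, self.bias)

# One forward/backward pass on a batch yields the per-weight saliencies:
model = nn.Sequential(MaskedLinear(784, 300), nn.ReLU(), MaskedLinear(300, 10))
x, y = torch.randn(32, 784), torch.randint(0, 10, (32,))
F.cross_entropy(model(x), y).backward()

# |dL/dmask| estimates the effect of removing each weight; keep the top-k.
saliency = torch.cat([m.weight_mask.grad.abs().flatten()
                      for m in model.modules() if isinstance(m, MaskedLinear)])
```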
Aren't you getting the same gradient when you collect it in terms of the weights, though? All I know is that since I made that switch my models actually work when I load them back in, and they get higher accuracies with only 20% weight retention using SNIP. I didn't look through all your code, though, so maybe you turn the weight gradients back on elsewhere, but the end result is the same. https://github.com/antonioverdi/MLReproChallenge
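For reference, the two gradients are related but not identical: writing the effective weights as $c = w \odot m$, the chain rule at $m = 1$ gives

$$
\frac{\partial L}{\partial m_{ij}} \;=\; \frac{\partial L}{\partial c_{ij}}\, w_{ij} \;=\; w_{ij}\,\frac{\partial L}{\partial w_{ij}}\bigg|_{m=1},
$$

i.e. the mask gradient is the weight gradient scaled by the weight itself, which is what turns "how sensitive is the loss to perturbing this weight" into "what happens if I remove it".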
The `SNIP()` function returns the pruning mask, not the model. In fact, it deep-copies the model at the beginning of the function to make sure it doesn't change the model at all. Applying the mask during actual training is another matter. See: https://github.com/mil-ad/snip/blob/883c8050d88554a8d85083013d7d1ddd7ead54c4/train.py#L31
In a nutshell, you'd have to set the pruned weights to zero and make sure they stay zero, either by (1) re-zeroing them before every forward pass or (2) zeroing their gradients after each call to `.backward()`.
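A sketch of option (2), in the spirit of what the linked `train.py` does (the helper name `apply_prune_mask` and the mask format are assumptions here: one 0/1 tensor per prunable layer, in module order):

```python
import torch
import torch.nn as nn

def apply_prune_mask(net, keep_masks):
    # Pair each prunable layer with its 0/1 keep-mask.
    prunable = [m for m in net.modules()
                if isinstance(m, (nn.Conv2d, nn.Linear))]
    for layer, mask in zip(prunable, keep_masks):
        layer.weight.data.mul_(mask)  # zero out the pruned weights once
        # Zero the corresponding gradients on every .backward(), so the
        # pruned weights stay at zero throughout training. The default
        # argument pins the current mask inside the closure.
        layer.weight.register_hook(lambda grad, m=mask: grad * m)
```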
Line 36 in snip.py should be `layer.weight_mask.requires_grad = False`, not `layer.weight.requires_grad = False` as it currently is, which essentially ruins the entire model :P
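For context, the pattern around that line is roughly the following (paraphrased, not a verbatim copy of snip.py); since it runs on the deep copy mentioned above, freezing `layer.weight` there only affects the throwaway copy used to score the mask:

```python
# Inside SNIP(), on the deep-copied network:
for layer in net.modules():
    if isinstance(layer, (nn.Conv2d, nn.Linear)):
        # Attach a multiplicative mask of ones; saliency is later read
        # from this mask's gradient.
        layer.weight_mask = nn.Parameter(torch.ones_like(layer.weight))
        # Freeze the copied weights so gradients are collected with
        # respect to the mask alone; the original model is untouched.
        layer.weight.requires_grad = False
```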