TimDettmers / sparse_learning

Sparse learning library and sparse momentum resources.
MIT License

Maintaining sparsity during backprop #16

Closed. ArvindSubramaniam closed this issue 4 years ago.

ArvindSubramaniam commented 4 years ago

Hi! I went through your code and found that you apply the mask to the entire network using the apply_mask function. However, do you also maintain the sparsity during backprop? For instance, are connections that were masked during forward prop reactivated during backprop, or do they stay masked? I was unable to find the code for this. Could you point me to where this is addressed?

TimDettmers commented 4 years ago

Yes, this is a bit tricky. I thought about this, and masking gradients is difficult to generalize across all architectures. What I do instead is mask the weights before training, for example: [0.3, 0.1, 0.2] with mask [1, 0, 1] becomes [0.3, 0.0, 0.2], and then I re-apply the mask to the weights directly after each weight update. I do not apply the mask to the gradients!

What does this look like in practice? Say the weight update is [0.1, 0.1, 0.1]; the weights become [0.4, 0.1, 0.3], you apply the mask [1, 0, 1], and you get [0.4, 0.0, 0.3]. This is the same as masking the incoming gradient. Note that the errors which are backpropagated should not be masked, and this is exactly what this procedure does.
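As a minimal sketch with plain tensors (the values are just the illustrative numbers from above):

import torch

weight = torch.tensor([0.3, 0.1, 0.2])
mask = torch.tensor([1.0, 0.0, 1.0])

# Mask the weights once before training: [0.3, 0.0, 0.2]
weight = weight * mask

# One (hypothetical) weight update: [0.4, 0.1, 0.3]
update = torch.tensor([0.1, 0.1, 0.1])
weight = weight + update

# Re-apply the mask directly after the update: [0.4, 0.0, 0.3]
weight = weight * mask

The end result is the same as if the update itself had been masked before being added.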

Does this make sense?

ArvindSubramaniam commented 4 years ago

Thanks for your answer! So basically, the masked weights are reactivated, but in the next forward prop iteration, they are masked again, right? So this would mean that the sparsity is reduced after every backprop, but is restored again in the forward prop?

What if you used hooks to mask the gradients? Could you do something like this:

for m in model.modules():
    if isinstance(m, nn.Conv2d):
        m.weight.register_hook(lambda grad: grad * mask)

You would have to ensure that you are simultaneously iterating through the masks and their corresponding weight matrices, which might be a little tricky though.
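One way to handle that pairing (just a sketch; the masks dict keyed by module name is my own assumption) could be:

import torch.nn as nn

# `masks` is assumed to map module names to 0/1 tensors with the same
# shape as the corresponding conv weights.
for name, m in model.named_modules():
    if isinstance(m, nn.Conv2d) and name in masks:
        # Bind the mask as a default argument so each hook keeps its own
        # mask rather than the loop variable.
        m.weight.register_hook(lambda grad, mask=masks[name]: grad * mask)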

TimDettmers commented 4 years ago

Yes, your interpretation is correct.

The solution you demonstrate is similar to one I thought of, but the problem with it is that you need an if-clause for each class that could possibly be used. With the current implementation you can use GRUs, LSTMs, transformers, and 1D, 2D, and 3D convolutions without any changes. The main advantage is that you do not have to go through every possible PyTorch class that can have sparse weights (a rough sketch of this is below). Logically, it should produce the same results. Would you agree?
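For concreteness, the weight-masking approach is roughly the following (a sketch only, not the actual apply_mask implementation; a masks dict keyed by parameter name is assumed):

import torch

# After every optimizer step, zero out the pruned weights again.
# This works for any layer type, with no per-class if-clauses.
optimizer.step()
with torch.no_grad():
    for name, param in model.named_parameters():
        if name in masks:
            param.mul_(masks[name])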

ArvindSubramaniam commented 4 years ago

Agreed. In this case, since the pruning is iterative, the results are exactly the same. However, in one-shot pruning, I guess you would have to mask the gradients.

Thanks for your time!

pranaymodukuru commented 4 years ago

I guess that, to mask the gradients, we could multiply the parameter gradients by the same masks before optimizer.step().

Please let me know if there are any disadvantages to this approach.
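Roughly, I mean something like this (a sketch; a masks dict keyed by parameter name is my assumption):

loss.backward()
# Zero out the gradients of pruned weights before the update.
for name, param in model.named_parameters():
    if name in masks and param.grad is not None:
        param.grad.mul_(masks[name])
optimizer.step()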

You also wrote:

"Note that the errors which are backpropagated should not be masked, and this is exactly what this procedure does."

Could you please explain why the backpropagated errors should not be masked?