TimDettmers / sparse_learning

Sparse learning library and sparse momentum resources.
MIT License

Using dynamic growth & pruning? #20

Open · varun19299 opened this issue 3 years ago

varun19299 commented 3 years ago

Hi, is it currently possible to use dynamic growth and pruning by just updating the masks at each step?

I'm looking to implement something like RigL (Evci et al. 2020).

TimDettmers commented 3 years ago

Yes, theoretically you can call the pruning function at each mini-batch iteration. If you look at the code, it is currently only called at the end of each epoch; you just need to call this function inside the training loop to get pruning/growth at each step.

Another option, which is already baked in but comes with predefined static behavior, is setting the prune_every_k_steps variable. Setting it to 1 executes the prune/regrowth cycle on every mini-batch.
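
For concreteness, here is a rough sketch of where this fits in a standard PyTorch training loop. The class and argument names below (Masking, CosineDecay, add_module, step, at_end_of_epoch, prune_every_k_steps) are assumptions based on the repository's core.py and main.py, not verified signatures; check the current code before copying.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sparselearning.core import Masking, CosineDecay  # names assumed from core.py

# Toy setup so the loop below is self-contained.
model = nn.Sequential(nn.Linear(784, 300), nn.ReLU(), nn.Linear(300, 10))
train_loader = DataLoader(TensorDataset(torch.randn(512, 784),
                                        torch.randint(0, 10, (512,))), batch_size=64)
num_epochs = 3
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Prune-rate schedule and mask wrapper. Argument names (prune_rate,
# prune_rate_decay, prune_mode, growth_mode, redistribution_mode, density)
# are assumptions based on the repository's main.py; verify before use.
decay = CosineDecay(0.5, len(train_loader) * num_epochs)
mask = Masking(optimizer, prune_rate=0.5, prune_rate_decay=decay,
               prune_mode='magnitude', growth_mode='momentum',
               redistribution_mode='momentum')
mask.add_module(model, density=0.2)

# Option 1 (already baked in): prune/regrow every k optimizer steps.
mask.prune_every_k_steps = 1   # may also be exposed as a constructor argument

for epoch in range(num_epochs):
    for x, y in train_loader:
        optimizer.zero_grad()
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        mask.step()            # applies masks; with prune_every_k_steps=1 this
                               # also runs the prune/regrowth cycle each step
    # Option 2 (default behaviour): prune/regrow only at the end of each epoch.
    # mask.at_end_of_epoch()
```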

varun19299 commented 3 years ago

Thank you for your reply!

I'm having a bit of trouble understanding how the calc_growth_redistribution method in core.py works and what it is meant to do. A few questions:

TimDettmers commented 3 years ago

Thanks for your comment. The method determines how weights are redistributed across layers. The problem is what to do when weights are redistributed to layers that are already full (and cannot regrow any weights), or when more weights are regrown than a layer can fit. I could keep track of these cases explicitly, but I found it easier and more general to anneal the redistribution over time (up to 1000 iterations).

The residual is the overflow from layers that are already too full, and it is redistributed over up to 1000 iterations. In some cases it is not easy to redistribute the weights and the annealing procedure does not converge within 1000 iterations. In that case, the best solution found after 1000 iterations is taken, but it might not be exactly proportional to the metric used to determine the redistribution fractions.
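
To make the annealing idea concrete, here is a simplified sketch (an illustration only, not the actual calc_growth_redistribution code): allocate growth proportionally, clip layers that would overflow their capacity, and push the residual onto layers that still have room, for up to 1000 iterations.

```python
# Simplified illustration of the redistribute-and-anneal idea; the function
# name and arguments are made up for this sketch.
def redistribute_growth(total_regrowth, fractions, capacity, max_iters=1000):
    """fractions[name] sum to 1.0 (and are positive); capacity[name] = free slots."""
    growth = {name: total_regrowth * frac for name, frac in fractions.items()}
    for _ in range(max_iters):
        # Clip over-full layers and collect the overflow (the "residual").
        residual = 0.0
        for name in growth:
            if growth[name] > capacity[name]:
                residual += growth[name] - capacity[name]
                growth[name] = capacity[name]
        if residual < 1.0:          # less than one weight left over: done
            break
        # Redistribute the residual proportionally over layers with room left.
        open_layers = [n for n in growth if growth[n] < capacity[n]]
        if not open_layers:
            break                   # every layer is full; keep the best effort
        share = sum(fractions[n] for n in open_layers)
        for n in open_layers:
            growth[n] += residual * fractions[n] / share
    return {name: int(round(g)) for name, g in growth.items()}
```

For example, with fractions {'a': 0.7, 'b': 0.3}, capacity {'a': 10, 'b': 100} and total_regrowth=50, layer a is clipped to 10 and the residual of 25 weights ends up in layer b, giving {'a': 10, 'b': 40}.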

I hope this makes it a bit clearer. It is definitely a confusing function, and I see that I forgot to clean up some artifacts, as you have pointed out. Let me know if you have more questions.

varun19299 commented 3 years ago

Hi Tim,

Thanks again for your reply! I have a few more questions :)

  1. I think the ERK initialisation may not be correct: there isn't a check for whether a layer's capacity is exhausted, i.e., p_{ERK, layer_i} > 1. (For instance, RigL marks such layers as dense and excludes them from the ERK-distributed set; see the sketch at the end of this comment.)

  2. Also, is this correct? https://github.com/TimDettmers/sparse_learning/blob/f99c2f2ee1e89a786e942c73c054c11912866488/sparselearning/core.py#L193

(should it be `abs(current_params - target_params) < tolerance`?)

  3. Even with this change, since the existing code doesn't explicitly check whether capacity is reached, the actual density is much lower than the input density.

Below is a comparison of the existing snippet versus RigL's implementation (the first block is from the existing code; the second, after the ======== line, is from RigL's implementation). Since there is no check on capacity, the actual density ends up lower than the intended one. In the output below, the intended density was 0.2 (i.e., 80% sparsity).

INFO:root:ERK block1.layer.0.conv1.weight: torch.Size([32, 16, 3, 3]) prob 0.7958154602321671
INFO:root:ERK block1.layer.0.conv2.weight: torch.Size([32, 32, 3, 3]) prob 0.5158063168171453
INFO:root:ERK block1.layer.0.convShortcut.weight: torch.Size([32, 16, 1, 1]) prob 6.631795501934726
INFO:root:ERK block1.layer.1.conv1.weight: torch.Size([32, 32, 3, 3]) prob 0.5158063168171453
INFO:root:ERK block1.layer.1.conv2.weight: torch.Size([32, 32, 3, 3]) prob 0.5158063168171453
INFO:root:ERK block1.layer.2.conv1.weight: torch.Size([32, 32, 3, 3]) prob 0.5158063168171453
INFO:root:ERK block1.layer.2.conv2.weight: torch.Size([32, 32, 3, 3]) prob 0.5158063168171453
INFO:root:ERK block2.layer.0.conv1.weight: torch.Size([64, 32, 3, 3]) prob 0.37580174510963443
INFO:root:ERK block2.layer.0.conv2.weight: torch.Size([64, 64, 3, 3]) prob 0.2468501659053481
INFO:root:ERK block2.layer.0.convShortcut.weight: torch.Size([64, 32, 1, 1]) prob 3.2495797959480157
INFO:root:ERK block2.layer.1.conv1.weight: torch.Size([64, 64, 3, 3]) prob 0.2468501659053481
INFO:root:ERK block2.layer.1.conv2.weight: torch.Size([64, 64, 3, 3]) prob 0.2468501659053481
INFO:root:ERK block2.layer.2.conv1.weight: torch.Size([64, 64, 3, 3]) prob 0.2468501659053481
INFO:root:ERK block2.layer.2.conv2.weight: torch.Size([64, 64, 3, 3]) prob 0.2468501659053481
INFO:root:ERK block3.layer.0.conv1.weight: torch.Size([128, 64, 3, 3]) prob 0.18237437630320497
INFO:root:ERK block3.layer.0.conv2.weight: torch.Size([128, 128, 3, 3]) prob 0.12066183482686793
INFO:root:ERK block3.layer.0.convShortcut.weight: torch.Size([128, 64, 1, 1]) prob 1.608210409219171
INFO:root:ERK block3.layer.1.conv1.weight: torch.Size([128, 128, 3, 3]) prob 0.12066183482686793
INFO:root:ERK block3.layer.1.conv2.weight: torch.Size([128, 128, 3, 3]) prob 0.12066183482686793
INFO:root:ERK block3.layer.2.conv1.weight: torch.Size([128, 128, 3, 3]) prob 0.12066183482686793
INFO:root:ERK block3.layer.2.conv2.weight: torch.Size([128, 128, 3, 3]) prob 0.12066183482686793
INFO:root:ERK fc.weight: torch.Size([10, 128]) prob 7.321502234135937
INFO:root:Overall sparsity 0.18062481420927468
INFO:root:========
INFO:root:Sparsity of var:fc.weight had to be set to 0.
INFO:root:Sparsity of var:block1.layer.0.convShortcut.weight had to be set to 0.
INFO:root:Sparsity of var:block2.layer.0.convShortcut.weight had to be set to 0.
INFO:root:Sparsity of var:block3.layer.0.convShortcut.weight had to be set to 0.
INFO:root:layer: block1.layer.0.conv1.weight, shape: torch.Size([32, 16, 3, 3]), density: 0.8874813710879286
INFO:root:layer: block1.layer.0.conv2.weight, shape: torch.Size([32, 32, 3, 3]), density: 0.5752194071866203
INFO:root:layer: block1.layer.0.convShortcut.weight, shape: torch.Size([32, 16, 1, 1]), density: 1.0
INFO:root:layer: block1.layer.1.conv1.weight, shape: torch.Size([32, 32, 3, 3]), density: 0.5752194071866203
INFO:root:layer: block1.layer.1.conv2.weight, shape: torch.Size([32, 32, 3, 3]), density: 0.5752194071866203
INFO:root:layer: block1.layer.2.conv1.weight, shape: torch.Size([32, 32, 3, 3]), density: 0.5752194071866203
INFO:root:layer: block1.layer.2.conv2.weight, shape: torch.Size([32, 32, 3, 3]), density: 0.5752194071866203
INFO:root:layer: block2.layer.0.conv1.weight, shape: torch.Size([64, 32, 3, 3]), density: 0.4190884252359663
INFO:root:layer: block2.layer.0.conv2.weight, shape: torch.Size([64, 64, 3, 3]), density: 0.2752835734393112
INFO:root:layer: block2.layer.0.convShortcut.weight, shape: torch.Size([64, 32, 1, 1]), density: 1.0
INFO:root:layer: block2.layer.1.conv1.weight, shape: torch.Size([64, 64, 3, 3]), density: 0.2752835734393112
INFO:root:layer: block2.layer.1.conv2.weight, shape: torch.Size([64, 64, 3, 3]), density: 0.2752835734393112
INFO:root:layer: block2.layer.2.conv1.weight, shape: torch.Size([64, 64, 3, 3]), density: 0.2752835734393112
INFO:root:layer: block2.layer.2.conv2.weight, shape: torch.Size([64, 64, 3, 3]), density: 0.2752835734393112
INFO:root:layer: block3.layer.0.conv1.weight, shape: torch.Size([128, 64, 3, 3]), density: 0.20338114754098363
INFO:root:layer: block3.layer.0.conv2.weight, shape: torch.Size([128, 128, 3, 3]), density: 0.13456025418115583
INFO:root:layer: block3.layer.0.convShortcut.weight, shape: torch.Size([128, 64, 1, 1]), density: 1.0
INFO:root:layer: block3.layer.1.conv1.weight, shape: torch.Size([128, 128, 3, 3]), density: 0.13456025418115583
INFO:root:layer: block3.layer.1.conv2.weight, shape: torch.Size([128, 128, 3, 3]), density: 0.13456025418115583
INFO:root:layer: block3.layer.2.conv1.weight, shape: torch.Size([128, 128, 3, 3]), density: 0.13456025418115583
INFO:root:layer: block3.layer.2.conv2.weight, shape: torch.Size([128, 128, 3, 3]), density: 0.13456025418115583
INFO:root:layer: fc.weight, shape: torch.Size([10, 128]), density: 1.0
INFO:root:Overall sparsity 0.2

Here's the source to produce this output.

Adding a threshold (something like `growth = min(prob, 1) * weight.numel()`) causes the code to take a long time to converge.
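
For reference, here is a sketch of a capacity-aware ERK allocation along the lines of RigL (Evci et al. 2020): layers whose raw ERK probability would exceed 1 are marked dense and removed from the ERK-distributed set, and epsilon is re-solved for the remaining layers. This is an illustration with made-up function and argument names, not a drop-in patch for sparselearning/core.py.

```python
import numpy as np

def erk_densities(shapes, density, erk_power_scale=1.0):
    """shapes: {layer_name: weight shape tuple}; density: global target density."""
    dense_layers = set()
    while True:
        divisor, rhs, raw_prob = 0.0, 0.0, {}
        for name, shape in shapes.items():
            n_param = np.prod(shape)
            if name in dense_layers:
                # A dense layer uses n_param * (1 - density) more nonzeros than
                # its fair share; subtract that from the budget.
                rhs -= n_param * (1.0 - density)
            else:
                rhs += n_param * density
                raw_prob[name] = (np.sum(shape) / np.prod(shape)) ** erk_power_scale
                divisor += raw_prob[name] * n_param
        eps = rhs / divisor
        max_prob = max(raw_prob.values())
        if eps * max_prob > 1.0:
            # Capacity exhausted: mark the offending layer(s) dense and re-solve.
            for name, p in raw_prob.items():
                if p == max_prob:
                    dense_layers.add(name)
        else:
            break
    return {name: 1.0 if name in dense_layers else eps * raw_prob[name]
            for name in shapes}
```

With the WideResNet shapes from the logs above and density=0.2, this should mark the 1x1 shortcut convolutions and fc.weight as dense (sparsity 0), matching the second block of output, while keeping the overall density at 0.2.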

TimDettmers commented 3 years ago

Great catch! Would you mind submitting a pull request for this? It seems like you would be able to pinpoint and fix this issue quickly.

varun19299 commented 3 years ago

Sure, I would be happy to contribute.

Would you prefer adding RigL's implementation of ERK for this? It does better than trying to tune epsilon for a given sparsity (as seen in the outputs above).

varun19299 commented 3 years ago

We based our code for RigL-reproducibility on @TimDettmers's sparselearning.

Our code has deviated significantly since then, but I could patch in the ERK initialisation change if it's still welcome.