TimDettmers / sparse_learning

Sparse learning library and sparse momentum resources.
MIT License

Need some help #15

Open jmandivarapu1 opened 4 years ago

jmandivarapu1 commented 4 years ago

Hi,

I am trying to run the MNIST code, but I am not sure which pruning rate and death rate values to use. I ran the code like this:

`python main.py --data mnist --model lenet5 --save-features --bench --growth momentum --redistribution momentum`

The accuracy gets high but then drops very low again. Can you provide the right arguments I need to pass?

[Screenshot of the training output, 2019-11-26]
TimDettmers commented 4 years ago

The problem is that LeNet-5 has a bottleneck with very few connections. If you happen to get a configuration in which no connections exist between one layer and the next, no gradient flows and learning stops (only the biases are learned in that case). You can avoid this by using random growth, which is able to find a connection pattern after an epoch, or you can try running the same model with a different random seed. In my experience, training is usually stable after epoch 20, but with very few weights (<5%) it might take a couple of tries until a solid connection is established.
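For illustration, here is a minimal PyTorch sketch (not the library's own masking code) of the failure mode described above: if the sparsity mask between two layers is all zeros, no gradient reaches the earlier weights, and only the bias after the dead layer still receives gradient. The layer sizes and the `mask2` tensor are made up for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical two-layer net; mask2 plays the role of a sparsity mask in
# which *every* connection between fc1 and fc2 has been pruned.
torch.manual_seed(0)
fc1 = nn.Linear(10, 20)
fc2 = nn.Linear(20, 5)
mask2 = torch.zeros_like(fc2.weight)  # the "dead bottleneck" case

x = torch.randn(4, 10)
y = torch.randint(0, 5, (4,))

h = F.relu(fc1(x))
# Apply the mask the way a sparse-training loop typically would.
out = F.linear(h, fc2.weight * mask2, fc2.bias)
loss = F.cross_entropy(out, y)
loss.backward()

print(fc1.weight.grad.abs().sum())  # tensor(0.) -- no gradient reaches fc1
print(fc2.weight.grad.abs().sum())  # tensor(0.) -- masked weights get no gradient either
print(fc2.bias.grad.abs().sum())    # > 0 -- only the bias after the dead layer keeps learning
```

To try the random-growth workaround, something like `python main.py --data mnist --model lenet5 --growth random --redistribution momentum` should work, assuming `random` is an accepted value for `--growth`; alternatively, change the random seed (e.g. via a `--seed` flag, if `main.py` exposes one) and rerun.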

This is a problem specific to LeNet-5. Other network architectures do not have such extreme bottlenecks and do not show this behavior.