liuzechun / MetaPruning

MetaPruning: Meta Learning for Automatic Neural Network Channel Pruning. In ICCV 2019.
MIT License
about mobilenet_v2 part #4

Closed zj19921221 closed 5 years ago

zj19921221 commented 5 years ago

I have two questions: 1) In mobilenet_v2.py, why is the parameter "affine" set to False, and what happens if it is set to True? 2) When I train the model on my own data, the accuracy on my validation set always stays the same. Looking forward to your answer!

zj19921221 commented 5 years ago

I have another question: how do I judge when to stop training the PruningNet? I trained the PruningNet, but the validation accuracy is very low.

liuzechun commented 5 years ago

1) "affine" is set to False because it may help stabilize training. I also ran an experiment with affine set to True and observed no obvious difference, so you can try setting it to True.

2) The validation accuracy is always near zero in our experiments, too. That is because BatchNorm is switched to evaluation mode at inference time. We handle this in the search code, which brings the validation accuracy back to normal. During PruningNet training the validation accuracy itself is not important; we only check whether it is NaN. If it is not NaN, the learning rate is not too high and the model is training properly.

3) We always look at the training accuracy. When it is higher than half of the accuracy of the original network to be pruned (normally 40%), we consider it high enough for the search. Empirically, we train the PruningNet for 1/4 of the epochs used to train a normal network from scratch.
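The stopping rule and the NaN check described above can be written down as a small helper. This is only a sketch in plain Python under stated assumptions: the function names `pruningnet_status` and `pruningnet_epochs`, the status strings, and the `ratio` parameter are hypothetical and do not come from the MetaPruning code; accuracies are assumed to be fractions in [0, 1].

```python
import math


def pruningnet_status(train_acc, val_acc, baseline_acc, ratio=0.5):
    """Classify the state of PruningNet training (hypothetical helper).

    train_acc    -- current training accuracy of the PruningNet
    val_acc      -- current validation accuracy (only checked for NaN)
    baseline_acc -- accuracy of the original, unpruned network
    ratio        -- fraction of baseline_acc treated as "high enough"
    """
    # A NaN validation accuracy suggests the learning rate is too high.
    if math.isnan(val_acc):
        return "diverged"
    # The absolute value of val_acc is ignored on purpose: it stays near
    # zero until BatchNorm statistics are recalibrated during the search.
    if train_acc > ratio * baseline_acc:
        return "ready_for_search"
    return "keep_training"


def pruningnet_epochs(full_training_epochs):
    """Epoch budget: a quarter of a normal from-scratch schedule."""
    return full_training_epochs // 4


if __name__ == "__main__":
    print(pruningnet_status(train_acc=0.41, val_acc=0.001, baseline_acc=0.72))
    print(pruningnet_epochs(120))
```

The threshold is kept as a ratio rather than a fixed number so the same check can be reused for a baseline other than the one discussed in this thread.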

zj19921221 commented 5 years ago

Thanks for your reply. After reading the code at lines 108-133, I almost understand why the accuracy in the training phase is very low. But another question: why not remove the validation part entirely when training the PruningNet? Thank you very much.

zj19921221 commented 5 years ago

Sorry to interrupt you again! I have another suggestion about the search code: in the training phase, we should set the "track_running_stats" parameter of the BN layers to False; and in search.py, in the "infer" function that recalibrates BatchNorm for each selected pruned network, we should set "track_running_stats" to True. Do you think that would be better? Looking forward to your answer.

liuzechun commented 5 years ago

Why not remove the validation part entirely when training the PruningNet?

Because we want to make sure the validation accuracy is not NaN during training. It is mainly for debugging.

I have another suggestion about the search code: in the training phase, we should set the "track_running_stats" parameter of the BN layers to False; and in search.py, in the "infer" function that recalibrates BatchNorm for each selected pruned network, we should set "track_running_stats" to True.

That is related to how BN is used in PyTorch. BN has two modes, train() and eval(). In train() mode, the mean and variance used for normalization are calculated from the current batch only. In eval() mode, BN uses the moving averages of the mean and variance that were accumulated during the preceding train() mode, which requires "track_running_stats=True". We don't need the moving averages during PruningNet training, because they mix many different channel-number choices in each layer and are therefore imprecise. We recalculate the BN statistics once the pruned network structure is fixed, so we set "track_running_stats=True" at inference time. That way BN stores the correct mean and variance for inferring the pruned network in the following eval() mode.
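To make the point about stale statistics concrete, here is a framework-free toy in plain Python. It is a sketch under stated assumptions, not the code in search.py: `TinyBatchNorm` and `recalibrate` are hypothetical names, the layer handles a single scalar feature without affine parameters, and it uses the biased variance everywhere for simplicity. The "wide" and "pruned" batches stand in for activations produced under two different channel configurations.

```python
import math


class TinyBatchNorm:
    """Toy batch norm over a list of floats (one feature, no affine)."""

    def __init__(self, momentum=0.1, eps=1e-5):
        self.momentum = momentum
        self.eps = eps
        self.training = True
        self.reset_running_stats()

    def reset_running_stats(self):
        self.running_mean = 0.0
        self.running_var = 1.0

    def __call__(self, batch):
        if self.training:
            # train mode: normalize with the current batch statistics
            mean = sum(batch) / len(batch)
            var = sum((x - mean) ** 2 for x in batch) / len(batch)
            # ... and fold them into the moving averages
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * var
        else:
            # eval mode: normalize with the stored moving averages
            mean, var = self.running_mean, self.running_var
        return [(x - mean) / math.sqrt(var + self.eps) for x in batch]


def recalibrate(bn, batches):
    """Recompute running statistics for one fixed structure, then go to eval."""
    bn.reset_running_stats()
    bn.training = True
    saved_momentum = bn.momentum
    for i, batch in enumerate(batches, start=1):
        bn.momentum = 1.0 / i  # cumulative average over the batches seen
        bn(batch)
    bn.momentum = saved_momentum
    bn.training = False


bn = TinyBatchNorm()
wide_batch = [9.0, 10.0, 11.0]    # activations under one channel choice
pruned_batch = [1.0, 2.0, 3.0]    # activations of the selected pruned net

# Training accumulates running statistics from the "wide" configuration.
for _ in range(50):
    bn(wide_batch)

# Evaluating the pruned structure with those stale statistics is badly off.
bn.training = False
stale_out = bn(pruned_batch)

# Recalibrating on a few batches of the pruned structure fixes eval mode.
recalibrate(bn, [pruned_batch] * 5)
fresh_out = bn(pruned_batch)

if __name__ == "__main__":
    print(sum(stale_out) / len(stale_out))
    print(sum(fresh_out) / len(fresh_out))
```

In this toy, the output normalized with stale statistics sits far from zero mean, while the recalibrated output is centered again. That matches the explanation above for why validation accuracy stays near zero during PruningNet training and recovers in the search code. In PyTorch, the corresponding pieces are the layer's train()/eval() modes and its running_mean and running_var buffers.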