changlin31 / DNA

(CVPR 2020) Block-wisely Supervised Neural Architecture Search with Knowledge Distillation

Doubts about the definition #18

Closed fusiming3 closed 3 years ago

fusiming3 commented 3 years ago

Thanks for your code! I have a question about the definition of _potential(supernet, teacher=None, eval_loader=None, loss_fn=None, args=None, stage=None, stage_model_pool=None, use_target=False, reset_data=None) in DNA/searching/dna/distill_train.py. Why should the BN layers in the teacher and the supernet be trained separately? As I understand it, BN layers are fixed during evaluation. In addition, the way the model pool is searched is complex. How is the model pool updated in _train_stage in DNA/searching/dna/distill_train.py?

changlin31 commented 3 years ago

Hi, @fusiming3

  1. BN layers are not trained but recalibrated before evaluation. The BN statistics (running_mean μ and running_var σ²) accumulated during supernet training are not correct. This is because the inputs can come from different paths, and the statistics are not updated at every step. A common technique to fix this is BN recalibration, which is defined here: https://github.com/changlin31/DNA/blob/570c708c950e8bf0a7d5f3dc949163ceb5e49b0a/searching/timm/utils.py#L338 When performing BN recalibration (or BN correction), the statistics (μ and σ²) of the BN layers are reset and recalculated on part of the training set, while the weights (γ) and biases (β) remain unchanged. After that, the model is set to eval mode and evaluated normally. This method is called multiple times in _potential().

  2. model_pool is a list containing encodings of all possible paths. This list is generated by https://github.com/changlin31/DNA/blob/570c708c950e8bf0a7d5f3dc949163ceb5e49b0a/searching/dna/distill_train.py#L865 and is used when sampling paths for training. By default, guide_input=True, so the model pool is discarded and regenerated before each stage. https://github.com/changlin31/DNA/blob/570c708c950e8bf0a7d5f3dc949163ceb5e49b0a/searching/dna/distill_train.py#L50-L51
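
For readers who want to see the BN recalibration from item 1 in code, here is a minimal PyTorch sketch. It is not the implementation in searching/timm/utils.py: the function name recalibrate_bn, the batches argument, and the choice of momentum=None (a plain cumulative average over the recalibration batches) are assumptions made for illustration.

```python
import torch
import torch.nn as nn


def recalibrate_bn(model, batches):
    """Reset BN running statistics and re-estimate them from `batches`.

    Hypothetical helper for illustration; gamma (weight) and beta (bias)
    are never modified because no optimizer step is taken.
    """
    bn_layers = [m for m in model.modules()
                 if isinstance(m, nn.modules.batchnorm._BatchNorm)]
    saved_momentum = [bn.momentum for bn in bn_layers]

    for bn in bn_layers:
        bn.reset_running_stats()  # running_mean -> 0, running_var -> 1
        bn.momentum = None        # assumption: equal-weight average over batches

    was_training = model.training
    model.train()                 # BN updates running stats only in train mode
    with torch.no_grad():         # forward passes only, no gradients
        for x in batches:
            model(x)

    # Restore the original momentum values and the original train/eval mode.
    for bn, momentum in zip(bn_layers, saved_momentum):
        bn.momentum = momentum
    model.train(was_training)
    return model


if __name__ == "__main__":
    torch.manual_seed(0)
    net = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU())
    net[1].running_mean.fill_(5.0)  # pretend these are stale supernet statistics
    data = [torch.randn(32, 4) for _ in range(10)]  # stand-in for part of the training set
    recalibrate_bn(net, data)
    net.eval()                      # evaluate normally with the recalibrated statistics
```

In a supernet setting, the sampled path would be selected before calling such a helper, so that the statistics match the sub-network being evaluated.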