DequanWang / tent

ICLR21 Tent: Fully Test-Time Adaptation by Entropy Minimization
https://arxiv.org/abs/2006.10726
MIT License

Why is the entire model set to train, instead of only BatchNorm modules? #17

Open hector-gr opened 1 year ago

hector-gr commented 1 year ago

The model is set to train mode in TENT so that BatchNorm modules use batch statistics at test time. However, this also sets other modules to train mode, for instance Dropout.

https://github.com/DequanWang/tent/blob/e9e926a668d85244c66a6d5c006efbd2b82e83e8/tent.py#L96-L110

It is also possible to set only the BatchNorm submodules to train mode, and I believe this would achieve the desired behaviour for BatchNorm without affecting other modules (see the sketch below). Is my understanding correct?
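
To illustrate what I mean, here is a minimal sketch (the helper name `set_bn_train_only` is mine, not from the repo):

```python
import torch.nn as nn

def set_bn_train_only(model: nn.Module) -> nn.Module:
    # Put everything in eval mode first, then switch only the BatchNorm
    # submodules back to train mode so they use batch statistics in the
    # forward pass. Dropout etc. stay in eval mode.
    model.eval()
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.train()
    return model
```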

forgotton-wind commented 1 year ago

Notice this line: model.requires_grad_(False). It disables gradients for every parameter in the model. The for-loop below it then re-enables gradients only for the BatchNorm modules.
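
For reference, the pattern in the linked lines is roughly this (a paraphrased sketch, not a verbatim copy of tent.py):

```python
import torch.nn as nn

def configure_model(model: nn.Module) -> nn.Module:
    model.train()                   # train mode, so BN uses batch statistics
    model.requires_grad_(False)     # freeze every parameter first
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.requires_grad_(True)  # re-enable grad only for BN affine params
    return model
```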

hector-gr commented 1 year ago

As I understand it, calling .train() on a PyTorch Module sets submodules like BatchNorm and Dropout to train mode (BatchNorm then uses batch statistics and updates its running averages; Dropout drops activations with a certain probability). On the other hand, .requires_grad_() controls whether gradients are computed for a module's parameters.
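
To make the distinction concrete, a small self-contained example (the module here is just for illustration):

```python
import torch.nn as nn

bn = nn.BatchNorm2d(8)

# .train()/.eval() set the *mode*, which changes the forward pass:
bn.train()   # forward uses batch statistics and updates running averages
bn.eval()    # forward uses the stored running statistics instead

# .requires_grad_() controls the *backward* pass for parameters:
bn.requires_grad_(False)
print(bn.weight.requires_grad)   # False
bn.train()                       # back in train mode...
print(bn.weight.requires_grad)   # ...still False: mode and grad are independent
```

So setting the whole model to train mode and freezing all parameters except BN are two orthogonal switches, which is why the question about Dropout still stands.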