facebookresearch / unbiased-teacher

PyTorch code for ICLR 2021 paper Unbiased Teacher for Semi-Supervised Object Detection
https://arxiv.org/abs/2102.09480
MIT License

"batch norm layers are not updated as well" for teacher #47

Closed: Divadi closed this issue 2 years ago

Divadi commented 2 years ago

Hello! I keep referencing this work because it's a really good resource for SSL.

I noticed that https://github.com/facebookresearch/unbiased-teacher/blob/226f8c4b7b9c714e52f9bf640da7c75343803c52/ubteacher/engine/trainer.py#L515 states that "batch norm layers are not updated as well" for the teacher. Why is this the case? From my understanding, batch statistics (the running mean and variance) are still updated under torch.no_grad() when the model is in train mode.
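A minimal check of this point, sketched with a plain `torch.nn.BatchNorm2d` rather than the Detectron2 layers the repository actually uses: in train mode a forward pass changes the running statistics even inside `torch.no_grad()`, while in eval mode it leaves them untouched. The input batch is arbitrary illustrative data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm2d(3)                 # standard (non-frozen) batch norm
x = torch.randn(8, 3, 4, 4) * 5 + 2    # batch whose mean/var differ from the initial (0, 1)

# Train mode: the forward pass updates running_mean / running_var,
# even though no gradients are being tracked.
bn.train()
before = bn.running_mean.clone()
with torch.no_grad():
    bn(x)
train_changed = not torch.equal(before, bn.running_mean)

# Eval mode: the forward pass normalizes with the stored statistics
# and does not modify them.
bn.eval()
before = bn.running_mean.clone()
with torch.no_grad():
    bn(x)
eval_changed = not torch.equal(before, bn.running_mean)

print(train_changed, eval_changed)  # True False
```

`torch.no_grad()` only disables gradient tracking; whether the statistics are updated is controlled by the module's train/eval mode.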

Also, have you tried setting the teacher to eval mode?

ycliu93 commented 2 years ago

Thanks for using our work! Any question or discussion is welcomed! 👍

Yes, I tried switching the teacher model to eval mode, and it gives the same mAP as train mode.

This raised the concern that, in train mode, the running mean and variance of the batch norm layers might be altered by feeding images through the model, even within no_grad.

So I checked and found that the batch normalization layers of the ResNet backbone are actually frozen (fixed) in Detectron2: https://github.com/facebookresearch/detectron2/blob/23f61b8b5188cc9be048778cacf2e4855e62206d/detectron2/layers/batch_norm.py#L13
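The idea behind a frozen batch norm layer like the one linked above can be sketched as follows. `FrozenBN2d` is a hypothetical, simplified class, not Detectron2's `FrozenBatchNorm2d`; it assumes NCHW input and stores the affine parameters and statistics as buffers, so they are neither trained by the optimizer nor updated by forward passes, regardless of train/eval mode.

```python
import torch
import torch.nn as nn


class FrozenBN2d(nn.Module):
    """Hypothetical minimal frozen batch norm (sketch, not Detectron2's class)."""

    def __init__(self, num_features: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Buffers, not Parameters: no gradients and no running-stat updates.
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fold the fixed statistics into a per-channel scale and shift.
        scale = self.weight * (self.running_var + self.eps).rsqrt()
        shift = self.bias - self.running_mean * scale
        return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)


torch.manual_seed(0)
fbn = FrozenBN2d(3)
x = torch.randn(8, 3, 4, 4) * 5 + 2

fbn.train()
with torch.no_grad():
    out_train = fbn(x)
fbn.eval()
with torch.no_grad():
    out_eval = fbn(x)

same_output = torch.equal(out_train, out_eval)
stats_unchanged = torch.equal(fbn.running_mean, torch.zeros(3))
print(same_output, stats_unchanged)  # True True
```

Because `forward` never reads `self.training`, the layer computes the same function in both modes, which is why the mode of the teacher does not affect these layers.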

The reasons for freezing these layers are discussed in the following issue: https://github.com/facebookresearch/maskrcnn-benchmark/issues/267

In short, in our implementation you don't need to worry about whether the teacher is in eval or train mode; both led to similar results in my previous experiments. I keep it in train mode only because it makes analysis easier.
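For a quick check of whether a particular teacher network could behave differently in train and eval mode, one could list the submodules whose forward pass depends on the mode. The helper below (`mode_sensitive_modules` is a hypothetical name) only covers the common cases, batch norm layers that track running statistics and dropout, so it is not exhaustive. In particular, model code can also branch on `self.training` directly (Detectron2's detectors do this in places), which a module-type check cannot detect.

```python
from typing import List

import torch.nn as nn


def mode_sensitive_modules(model: nn.Module) -> List[str]:
    """Hypothetical helper: names of submodules whose forward pass differs
    between train and eval mode (batch norm with running stats, dropout).
    Not exhaustive; code branching on `self.training` is not detected."""
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    dropout_types = (nn.Dropout, nn.Dropout2d, nn.Dropout3d)
    names = []
    for name, m in model.named_modules():
        # Batch norm without running stats uses batch statistics in both modes.
        if isinstance(m, bn_types) and m.track_running_stats:
            names.append(name)
        elif isinstance(m, dropout_types):
            names.append(name)
    return names


teacher = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
no_bn = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
print(mode_sensitive_modules(teacher))  # ['1']
print(mode_sensitive_modules(no_bn))    # []
```

An empty list suggests that, at least at the layer level, switching the teacher between train and eval mode should not change its outputs.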

Divadi commented 2 years ago

Thank you for your comment - I understand!