facebookresearch / unbiased-teacher

PyTorch code for ICLR 2021 paper Unbiased Teacher for Semi-Supervised Object Detection
https://arxiv.org/abs/2102.09480
MIT License

update student model #27

Closed hachreak closed 3 years ago

hachreak commented 3 years ago

Hi everybody, I would like to update the student model instead of updating the teacher. I copied the function and swapped self.model with self.model_teacher, but in a distributed run I get this error because the key is not in teacher_model_dict.keys():

Exception during training:
Traceback (most recent call last):
  File "/hpc/home/leonardo.rossi/src/unbiased-teacher/ubteacher/engine/trainer.py", line 339, in train_loop
    self.run_step_full_semisup()
  File "/hpc/home/leonardo.rossi/src/unbiased-teacher/ubteacher/engine/trainer.py", line 908, in run_step_full_semisup
    self._update_student_model(0.0)
  File "/hpc/home/leonardo.rossi/.conda/envs/mmdetection23/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "/hpc/home/leonardo.rossi/src/unbiased-teacher/ubteacher/engine/trainer.py", line 927, in _update_student_model
    raise Exception("{} is not found in teacher model".format(key))
Exception: module.backbone.fpn_lateral2.weight is not found in teacher model

Do you know how I should fix this? Thanks a lot! :smile:
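For context, the key mismatch happens because DistributedDataParallel wraps the student model and prepends `module.` to every key in its state_dict, while the unwrapped teacher keeps the bare keys. A minimal sketch of the stripping logic (plain dicts instead of real state_dicts, purely illustrative):

```python
# Hypothetical state dicts: the DDP-wrapped student carries a "module."
# prefix on every key; the unwrapped teacher does not.
student_state = {"module.backbone.fpn_lateral2.weight": 1.0}
teacher_state = {"backbone.fpn_lateral2.weight": 2.0}

def strip_prefix(key, prefix="module."):
    """Remove the DDP wrapper prefix so student keys match teacher keys."""
    return key[len(prefix):] if key.startswith(prefix) else key

# After stripping, every student key can be looked up in the teacher dict.
for key in student_state:
    assert strip_prefix(key) in teacher_state
```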

hachreak commented 3 years ago

Maybe I found how to resolve it:

    @torch.no_grad()
    def _update_student_model(self, keep_rate=0.996):
        teacher_model_dict = self.model_teacher.state_dict()

        # Under DDP the student keys carry a "module." prefix that the
        # teacher keys lack, so strip it before the lookup.
        prefix = ''
        if comm.get_world_size() > 1:
            prefix = 'module.'

        new_student_dict = OrderedDict()
        for key, value in self.model.state_dict().items():
            tkey = key
            if prefix and key.startswith(prefix):
                tkey = key[len(prefix):]
            if tkey in teacher_model_dict.keys():
                new_student_dict[key] = (
                    teacher_model_dict[tkey] * (1 - keep_rate) + value * keep_rate
                )
            else:
                raise Exception("{} is not found in teacher model".format(key))

        self.model.load_state_dict(new_student_dict)

But I see a strange behavior: after the teacher -> student sync, the performance of the student starts to drop! Do you have any idea why? What I do is similar to the burn-up stage (where the teacher is fully synced with the student):

    self._update_teacher_model(keep_rate=0.00)

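For reference, the EMA update in the snippet above reduces to a plain copy when `keep_rate=0.0`: each new weight becomes exactly the other model's weight, discarding the current one. A toy sketch with scalar "weights" (illustrative only, not the repo's code):

```python
def ema_update(other_w, current_w, keep_rate):
    """EMA interpolation as in the _update_student_model sketch:
    keep `keep_rate` of the current model's weight and take
    (1 - keep_rate) from the other model's weight."""
    return other_w * (1 - keep_rate) + current_w * keep_rate

# keep_rate=0.0 discards the current weights entirely (full copy),
# which is what keep_rate=0.00 does at the end of burn-up.
full_copy = ema_update(2.0, 1.0, 0.0)      # -> 2.0
slow_drift = ema_update(2.0, 1.0, 0.996)   # stays close to 1.0
```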
ycliu93 commented 3 years ago

Hello @hachreak ,

Sorry for not seeing this issue earlier. Did you mean that when the teacher model is completely copied from the student model (without keeping any teacher weights from previous steps), the student model starts to degrade?