Closed by hachreak 3 years ago
I may have found a way to resolve this:
@torch.no_grad()
def _update_student_model(self, keep_rate=0.996):
    # Needs `import torch` and `from collections import OrderedDict` at module level.
    teacher_model_dict = self.model_teacher.state_dict()

    new_student_dict = OrderedDict()
    for key, value in self.model.state_dict().items():
        # Student keys may carry a 'module.' prefix in a distributed run;
        # strip it so the key can be looked up in the teacher's state dict.
        tkey = key
        if key.startswith('module.'):
            tkey = key[7:]
        if tkey in teacher_model_dict.keys():
            # Keep `keep_rate` of the student weight, take the rest from the teacher.
            new_student_dict[key] = (
                teacher_model_dict[tkey] * (1 - keep_rate) + value * keep_rate
            )
        else:
            raise Exception("{} is not found in teacher model".format(key))

    self.model.load_state_dict(new_student_dict)
But I see some strange behavior: after the teacher -> student sync, the student's performance starts to go down! Do you have any idea why? I ask because I do something similar to the burn-up stage (when the teacher is fully synced with the student):
self._update_teacher_model(keep_rate=0.00)
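To make the role of `keep_rate` concrete, here is a minimal sketch of the same blending rule on plain Python floats instead of tensors. The helper name `ema_blend` and the key `backbone.weight` are made up for illustration and are not part of the repository.

```python
def ema_blend(target, source, keep_rate):
    """Blend per key: keep_rate * target + (1 - keep_rate) * source."""
    return {
        key: source[key] * (1 - keep_rate) + value * keep_rate
        for key, value in target.items()
    }


# Hypothetical one-parameter "models", with floats standing in for tensors.
teacher = {"backbone.weight": 1.0}
student = {"backbone.weight": 3.0}

# keep_rate=0.0 discards the old target value, so the result copies the source.
full_copy = ema_blend(teacher, student, keep_rate=0.0)

# keep_rate=0.996 keeps 99.6% of the old target and takes 0.4% from the source.
small_step = ema_blend(teacher, student, keep_rate=0.996)
```

With `keep_rate=0.0` the target becomes an exact copy of the source, which is the "fully synced" case; with the default `0.996` the target only moves a small step toward the source on each call.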
Hello @hachreak,
Sorry for not seeing this issue earlier. Did you mean that when the teacher model is copied completely from the student model (without keeping any of the teacher weights from previous steps), the student model starts to degrade?
Hi everybody, I would like to try updating the student model instead of the teacher. I copied the function and swapped self.model with self.model_teacher, but in a distributed run I get an error because the key is not in teacher_model_dict.keys().
Do you know what I should do? Thanks a lot! :smile:
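One possible explanation for the missing key, sketched on plain lists of key names: PyTorch's `DistributedDataParallel` exposes the wrapped model under `.module`, so its state-dict keys start with `module.`. The sketch assumes the student is wrapped and the teacher is not, which is what the `startswith('module.')` check elsewhere in this thread suggests. The key names and the helper `strip_ddp_prefix` are hypothetical.

```python
DDP_PREFIX = "module."


def strip_ddp_prefix(key):
    """Drop a leading 'module.' if present; leave other keys untouched."""
    return key[len(DDP_PREFIX):] if key.startswith(DDP_PREFIX) else key


# Hypothetical key names: student wrapped for distributed training, teacher not.
student_keys = ["module.backbone.weight", "module.head.bias"]
teacher_keys = ["backbone.weight", "head.bias"]

# Slicing off 7 characters unconditionally corrupts keys that have no prefix.
sliced = [key[7:] for key in teacher_keys]

# Student keys that cannot be matched with each approach.
missing_naive = [k for k in student_keys if k not in sliced]
missing_fixed = [k for k in student_keys if strip_ddp_prefix(k) not in teacher_keys]
```

Under these assumptions, swapping the two models while keeping an unconditional `key[7:]` slice leaves every student key unmatched, whereas stripping the prefix only when it is present matches all of them.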