ShiqiYu / OpenGait

A flexible and extensible framework for gait recognition. You can focus on designing your own models and comparing with state-of-the-arts easily with the help of OpenGait.

Some confusion when using GradScaler with multiple optimizers #147

Closed: enemy1205 closed this issue 8 months ago

enemy1205 commented 11 months ago

I want to use multiple optimizers together with AMP, so I inherited and rewrote the relevant code, but I ran into a problem with the gradient backward step here. Original:

        if self.engine_cfg['enable_float16']:
            self.Scaler.scale(loss_sum).backward()
            self.Scaler.step(self.optimizer)
            scale = self.Scaler.get_scale()
            self.Scaler.update()
            # Warning caused by optimizer skip when NaN
            # https://discuss.pytorch.org/t/optimizer-step-before-lr-scheduler-step-error-using-gradscaler/92930/5
            if scale != self.Scaler.get_scale():
                self.msg_mgr.log_debug("Training step skip. Expected the former scale equals to the present, got {} and {}".format(
                    scale, self.Scaler.get_scale()))
                return False

Rewritten:

        if self.engine_cfg['enable_float16']:
            self.Scaler.scale(loss_sum).backward()
            # self.optimizer is now a list containing the different optimizers
            for optimizer in self.optimizer:
                self.Scaler.step(optimizer)
            scale = self.Scaler.get_scale()
            self.Scaler.update()
            # Warning caused by optimizer skip when NaN
            # https://discuss.pytorch.org/t/optimizer-step-before-lr-scheduler-step-error-using-gradscaler/92930/5
            if scale != self.Scaler.get_scale():
                self.msg_mgr.log_debug("Training step skip. Expected the former scale equals to the present, got {} and {}".format(
                    scale, self.Scaler.get_scale()))
                return False

Then, in `if scale != self.Scaler.get_scale()`, `scale` always turns out to be half of `self.Scaler.get_scale()`. I know this is a problem in my code, but I don't know how to use `self.Scaler` correctly. I would appreciate an example of the correct modification, thanks!
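For reference, the pattern documented in PyTorch's AMP examples ("Working with multiple models, losses, and optimizers") matches the rewrite above: one `backward()` on the scaled loss, one `Scaler.step()` per optimizer, and exactly one `Scaler.update()` per iteration. A minimal self-contained sketch, standalone PyTorch rather than OpenGait code; the model, optimizers, and data here are made up for illustration, and a CUDA device is assumed:

    import torch

    # Hypothetical setup: two sub-modules trained by separate optimizers.
    device = "cuda"
    model = torch.nn.Sequential(
        torch.nn.Linear(8, 8), torch.nn.Linear(8, 2)).to(device)
    opt_a = torch.optim.SGD(model[0].parameters(), lr=1e-2)
    opt_b = torch.optim.Adam(model[1].parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler()

    inputs = torch.randn(4, 8, device=device)
    targets = torch.randint(0, 2, (4,), device=device)

    opt_a.zero_grad()
    opt_b.zero_grad()
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.cross_entropy(model(inputs), targets)

    scaler.scale(loss).backward()  # one backward pass on the scaled loss
    scaler.step(opt_a)             # step() unscales this optimizer's grads and
    scaler.step(opt_b)             # skips the step if any grad is inf/NaN
    scale_before = scaler.get_scale()
    scaler.update()                # exactly one update() per iteration
    # update() lowers the scale (backoff) only when a step was skipped;
    # it *raises* it after `growth_interval` consecutive successful steps.
    if scaler.get_scale() < scale_before:
        print("optimizer step skipped; loss scale reduced")

Note that `update()` changes the scale in two cases: it multiplies by `backoff_factor` (default 0.5) after a skipped step, and by `growth_factor` (default 2.0) after `growth_interval` consecutive successful steps. The `!=` comparison therefore also fires on normal growth, where the old scale is exactly half of the new one; that matches the symptom described, since the branch only runs when the scale changed at all. Comparing with `<`, as in the sketch, distinguishes a real skip from growth.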

github-actions[bot] commented 9 months ago

Stale issue message