ML-GSAI / EGSDE

Official implementation for "EGSDE: Unpaired Image-to-Image Translation via Energy-Guided Stochastic Differential Equations" (NIPS 2022)

Missing EMA model loading in DDPM #9

Open miquel-espinosa opened 1 year ago

miquel-espinosa commented 1 year ago

Hi. Thanks for this work! A couple of questions.

When running run_EGSDE.py with a DDPM model trained with the DDIM codebase, the code seems to load the model weights but not the EMA (exponential moving average) weights. As far as I know, the diffusion-model community usually samples from the EMA weights because they tend to give better results (but I'm not sure).
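
For context, the EMA weights are a second copy of the parameters that is updated after every optimizer step as shadow = mu * shadow + (1 - mu) * param, and copied into the model only for sampling. The sketch below shows that update rule under stated assumptions: parameters are plain Python lists instead of torch tensors so it runs without PyTorch, and `SimpleEMA` is a hypothetical name, not the DDIM `EMAHelper` API (only the method names are modelled on it).

```python
from typing import Dict, List

# Hypothetical parameter container: name -> flat list of floats (stands in for tensors).
Params = Dict[str, List[float]]


class SimpleEMA:
    """Illustrative EMA of model weights; not the actual EMAHelper class."""

    def __init__(self, mu: float = 0.999) -> None:
        self.mu = mu  # decay rate; values close to 1 average over many steps
        self.shadow: Params = {}

    def register(self, params: Params) -> None:
        # Start the shadow copy from the current weights (copied, not aliased).
        self.shadow = {name: list(values) for name, values in params.items()}

    def update(self, params: Params) -> None:
        # shadow <- mu * shadow + (1 - mu) * param, called after each optimizer step.
        for name, values in params.items():
            self.shadow[name] = [
                self.mu * s + (1.0 - self.mu) * p
                for s, p in zip(self.shadow[name], values)
            ]

    def ema(self, params: Params) -> None:
        # Overwrite the live weights in place with the averaged ones before sampling.
        for name in params:
            params[name][:] = self.shadow[name]


if __name__ == "__main__":
    weights = {"w": [0.0]}
    helper = SimpleEMA(mu=0.5)
    helper.register(weights)
    for _ in range(3):
        weights["w"] = [1.0]  # pretend the optimizer moved the weight to 1.0
        helper.update(weights)
    print(helper.shadow["w"])  # [0.875]
    helper.ema(weights)
    print(weights["w"])  # [0.875]
```

The averaged weight lags behind the raw one (0.875 vs 1.0 here), which smooths out step-to-step noise; that smoothing is the usual argument for sampling from the EMA copy.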

Here is the code I had to add in runners/egsde.py, at line 114 (in addition to importing `EMAHelper`):

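            # Build the network and load the raw (non-EMA) weights from states[0].
            model = Model(config)
            states = torch.load(self.args.ckpt, map_location=self.device)
            model = model.to(self.device)
            # DDIM checkpoints are saved from a DataParallel-wrapped model ("module." keys).
            model = torch.nn.DataParallel(model)
            model.load_state_dict(states[0], strict=True)

            # Overwrite the weights with the EMA copy stored as the last checkpoint
            # entry. Use the same `config` the model was built from.
            # EMAHelper must be imported at the top of the file (in the DDIM
            # codebase it lives in models/ema.py; the path may differ here).
            if config.model.ema:
                ema_helper = EMAHelper(mu=config.model.ema_rate)
                ema_helper.register(model)
                ema_helper.load_state_dict(states[-1])
                ema_helper.ema(model)
            else:
                ema_helper = None

            model.eval()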
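
The snippet first loads the raw weights with strict=True and only then overwrites them with the EMA copy, so any entry the EMA does not track keeps its raw value. The sketch below shows that net effect on plain dicts. It assumes a DDIM-style checkpoint list laid out as [model, optimizer, epoch, step, ema] and assumes the EMA keys lack the `module.` prefix because the helper unwraps `DataParallel` when registering. `merge_ema_weights` is a hypothetical helper, not part of EGSDE or DDIM.

```python
from typing import Any, Dict


def merge_ema_weights(
    model_state: Dict[str, Any],
    ema_shadow: Dict[str, Any],
    prefix: str = "module.",
) -> Dict[str, Any]:
    """Return a copy of model_state where every entry with an EMA copy is replaced.

    Assumption: model_state keys carry the DataParallel prefix, ema_shadow keys
    do not. Entries without an EMA copy keep their raw values.
    """
    merged = dict(model_state)  # shallow copy; the checkpoint dict is left untouched
    for name, value in ema_shadow.items():
        key = prefix + name
        if key not in merged:
            raise KeyError(f"EMA entry {name!r} has no matching model key {key!r}")
        merged[key] = value
    return merged


if __name__ == "__main__":
    # Toy stand-in for a DDIM-style checkpoint: [model, optimizer, epoch, step, ema]
    states = [
        {"module.conv.weight": 1.0, "module.conv.bias": 0.5},
        {},    # optimizer state (unused for sampling)
        10,    # epoch
        1000,  # step
        {"conv.weight": 0.9, "conv.bias": 0.45},
    ]
    weights = merge_ema_weights(states[0], states[-1])
    print(weights)  # {'module.conv.weight': 0.9, 'module.conv.bias': 0.45}
```

Raising on an unmatched key mirrors the strictness of `load_state_dict(..., strict=True)`: a prefix mismatch fails loudly instead of silently sampling from the raw weights.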

Thanks,