Hi, thanks for this work!

When running run_EGSDE.py with a DDPM checkpoint trained with the DDIM codebase, the code seems to load the model weights but not the EMA (exponential moving average) weights. As far as I know, the diffusion-model community usually samples from the EMA weights because they tend to give better results, but I'm not certain.
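For context on what the EMA weights are: an EMA helper keeps a "shadow" copy of every parameter and, after each training step, updates it as shadow ← mu · shadow + (1 − mu) · param, where mu is the decay rate (config.model.ema_rate in the DDIM config). At sampling time the shadow values are copied into the model. The sketch below is a minimal, pure-Python illustration of that rule, assuming parameters are plain floats keyed by name so it runs without PyTorch; the class name SimpleEMA is hypothetical and only loosely mirrors the register / update / ema methods of the DDIM repo's EMAHelper.

```python
from typing import Dict


class SimpleEMA:
    """Hypothetical, simplified EMA of named scalar weights (illustration only)."""

    def __init__(self, mu: float = 0.999) -> None:
        self.mu = mu  # decay rate, analogous to config.model.ema_rate
        self.shadow: Dict[str, float] = {}

    def register(self, params: Dict[str, float]) -> None:
        # Start the shadow copy from the current weights (a copy, not a reference).
        self.shadow = dict(params)

    def update(self, params: Dict[str, float]) -> None:
        # shadow <- mu * shadow + (1 - mu) * param, called after each optimizer step.
        for name, value in params.items():
            self.shadow[name] = self.mu * self.shadow[name] + (1.0 - self.mu) * value

    def ema(self, params: Dict[str, float]) -> None:
        # Overwrite the live weights with the averaged ones, as done before sampling.
        params.update(self.shadow)


params = {"w": 0.0}
helper = SimpleEMA(mu=0.9)
helper.register(params)
params["w"] = 1.0        # pretend an optimizer step moved the weight
helper.update(params)    # shadow["w"] is now about 0.1
helper.ema(params)       # params["w"] now holds the averaged value
```

Because the shadow moves only a fraction (1 − mu) toward each new value, it smooths out step-to-step noise in the weights, which is the usual argument for sampling from it.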
Here is the code I had to add to runners/egsde.py at line 114 (in addition to importing the EMA helper):
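# Requires EMAHelper to be imported at the top of runners/egsde.py (e.g. from the DDIM repo's models/ema.py).
model = Model(config)
# map_location lets a checkpoint saved on another device load onto self.device.
states = torch.load(self.args.ckpt, map_location=self.device)
model = model.to(self.device)
# Wrap before loading: DDIM checkpoints are saved from a DataParallel model, so keys carry a "module." prefix.
model = torch.nn.DataParallel(model)
model.load_state_dict(states[0], strict=True)
# Load EMA weights for the DDIM model (states[-1] holds the EMA state when training used model.ema=True).
if self.config.model.ema:
    ema_helper = EMAHelper(mu=self.config.model.ema_rate)
    ema_helper.register(model)
    ema_helper.load_state_dict(states[-1])
    ema_helper.ema(model)  # copy the averaged weights into the model
else:
    ema_helper = None
model.eval()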
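One caveat with the snippet above: in the DDIM training code the checkpoint appears to be saved as a list [model_state, optimizer_state, epoch, step], with the EMA state appended only when config.model.ema is enabled. So states[-1] is the EMA state only in that case; otherwise it is the integer step counter. A small guard can make that failure explicit. This is a sketch under that assumed layout, in plain Python with dicts standing in for state dicts; the helper name pick_state_dict is hypothetical.

```python
from typing import Any, Dict, List


def pick_state_dict(states: List[Any], use_ema: bool) -> Dict[str, Any]:
    """Choose which weights to load from a DDIM-style checkpoint list.

    Assumed layout (hypothetical helper, check against how the checkpoint was saved):
        [model_state, optimizer_state, epoch, step, (ema_state)]
    where ema_state is present only if EMA was enabled during training.
    """
    if not use_ema:
        return states[0]  # raw model weights
    # Without EMA the last entry is the integer step, not a state dict.
    if len(states) < 5 or not isinstance(states[-1], dict):
        raise ValueError("checkpoint has no EMA state; was it trained with model.ema=True?")
    return states[-1]  # EMA shadow weights
```

In runners/egsde.py this would replace the bare states[-1] lookup, e.g. ema_helper.load_state_dict(pick_state_dict(states, use_ema=True)), so that a checkpoint trained without EMA fails with a clear message instead of an opaque error inside load_state_dict.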
Thanks,