cyclomon / UNSB

Official Repository of "Unpaired Image-to-Image Translation via Neural Schrödinger Bridge" (ICLR 2024)
MIT License

Using the hyperparameter weight lambda_SB #15

Closed: juanprietob closed this issue 7 months ago.

juanprietob commented 9 months ago

The related code is here: https://github.com/cyclomon/UNSB/blob/main/models/sb_model.py#L307-L315

    if self.opt.lambda_SB > 0.0:
        XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B], dim=1)
        XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B2], dim=1)
        bs = self.opt.batch_size
        ET_XY = self.netE(XtXt_1, self.time_idx, XtXt_1).mean() - torch.logsumexp(self.netE(XtXt_1, self.time_idx, XtXt_2).reshape(-1), dim=0)
        self.loss_SB = -(self.opt.num_timesteps-self.time_idx[0])/self.opt.num_timesteps*self.opt.tau*ET_XY
        self.loss_SB += self.opt.tau*torch.mean((self.real_A_noisy-self.fake_B)**2)
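The arithmetic of the snippet above can be traced with a framework-free sketch that mirrors it using plain Python floats instead of torch tensors. This is an illustration under assumptions, not the repository's code: the two netE outputs are taken as already-computed flat score lists, the images as flattened lists, and all function and argument names are hypothetical. It makes visible that, within this snippet, tau multiplies both terms while lambda_SB appears only in the if guard.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) over a flat list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def sb_loss_sketch(e_pos, e_neg, x_noisy, x_fake, t_idx, num_timesteps, tau):
    """Mirror the loss_SB arithmetic of the quoted snippet (hypothetical names).

    e_pos:   flat scores standing in for netE(XtXt_1, t, XtXt_1)
    e_neg:   flat scores standing in for netE(XtXt_1, t, XtXt_2)
    x_noisy: flattened stand-in for real_A_noisy
    x_fake:  flattened stand-in for fake_B
    t_idx:   plays the role of self.time_idx[0]
    """
    # ET_XY: mean of the first scores minus logsumexp over the second scores
    et_xy = sum(e_pos) / len(e_pos) - logsumexp(e_neg)
    # Energy term, weighted by (T - t) / T and by tau
    loss = -(num_timesteps - t_idx) / num_timesteps * tau * et_xy
    # Squared-distance term, also weighted by tau (the line the question is about)
    mse = sum((a - b) ** 2 for a, b in zip(x_noisy, x_fake)) / len(x_noisy)
    loss += tau * mse
    return loss


# Toy inputs: result is 0.05 - 0.075 * (1 - ln 2), roughly 0.027
loss = sb_loss_sketch([1.0, 1.0], [0.0, 0.0], [1.0, 0.0], [0.0, 0.0],
                      t_idx=1, num_timesteps=4, tau=0.1)
print(loss)
```

Doubling tau doubles the whole value, so in this snippet tau already acts as an overall scale on the SB loss, which is why a separate lambda_SB weight would otherwise have no effect beyond the on/off guard.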

Specifically, this line:

    self.loss_SB += self.opt.tau*torch.mean((self.real_A_noisy-self.fake_B)**2)

Shouldn't self.opt.tau here be replaced by the weight self.opt.lambda_SB?

cyclomon commented 7 months ago

Hi, thank you for your comment. I have revised the source code.

In line 327, I changed self.loss_SB to self.opt.lambda_SB*self.loss_SB.
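A minimal sketch of what this revision changes, assuming (as the question suggests) that lambda_SB was previously used only in the if self.opt.lambda_SB > 0.0: guard. Before the change, every positive lambda_SB gave the same SB loss. After it, the loss scales linearly with lambda_SB, so the effective weight on each of the two terms becomes lambda_SB * tau. The helper names are hypothetical, the already-computed loss is passed in as a plain number, and returning 0.0 when lambda_SB is 0 is an assumption, since the snippet does not show that branch.

```python
def sb_term_before_fix(loss_sb, lambda_sb):
    """Before the revision: lambda_SB only switches the SB loss on or off."""
    # Assumption: a disabled SB loss contributes 0.0
    return loss_sb if lambda_sb > 0.0 else 0.0


def sb_term_after_fix(loss_sb, lambda_sb):
    """After the revision: the SB loss is multiplied by lambda_SB."""
    # Same assumption for the disabled case
    return lambda_sb * loss_sb if lambda_sb > 0.0 else 0.0


loss_sb = 0.5  # placeholder for an already-computed self.loss_SB
for lam in (0.5, 1.0, 2.0):
    print(lam, sb_term_before_fix(loss_sb, lam), sb_term_after_fix(loss_sb, lam))
# 0.5 0.5 0.25
# 1.0 0.5 0.5
# 2.0 0.5 1.0
```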