LeiaLi / SRDiff


Implementation on own Dataset #5

Closed: Frankie91 closed this issue 1 year ago

Frankie91 commented 1 year ago

Hello,

I would like to train your model on my own dataset with a different LR encoder, but I'm struggling to understand exactly how to carry out the training step (particularly the loss calculation) from the information and code you provide. Specifically, could you clarify the role of the following two functions in SRDiff/models/diffusion.py:

def p_losses(self, x_start, t, cond, img_lr_up, noise=None):
    noise = default(noise, lambda: torch.randn_like(x_start))
    x_tp1_gt = self.q_sample(x_start=x_start, t=t, noise=noise)
    x_t_gt = self.q_sample(x_start=x_start, t=t - 1, noise=noise)
    noise_pred = self.denoise_fn(x_tp1_gt, t, cond, img_lr_up)
    x_t_pred, x0_pred = self.p_sample(x_tp1_gt, t, cond, img_lr_up, noise_pred=noise_pred)

    if self.loss_type == 'l1':
        loss = (noise - noise_pred).abs().mean()
    elif self.loss_type == 'l2':
        loss = F.mse_loss(noise, noise_pred)
    elif self.loss_type == 'ssim':
        loss = (noise - noise_pred).abs().mean()
        loss = loss + (1 - self.ssim_loss(noise, noise_pred))
    else:
        raise NotImplementedError()
    return loss, x_tp1_gt, noise_pred, x_t_pred, x_t_gt, x0_pred

def q_sample(self, x_start, t, noise=None):
    noise = default(noise, lambda: torch.randn_like(x_start))
    t_cond = (t[:, None, None, None] >= 0).float()
    t = t.clamp_min(0)
    return (
        extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
    ) * t_cond + x_start * (1 - t_cond)
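For reference, the two functions above can be written out as formulas. This is a sketch of the standard DDPM forward process and noise-prediction loss, assuming `sqrt_alphas_cumprod[t]` stores $\sqrt{\bar\alpha_t}$ with $\bar\alpha_t = \prod_{s=0}^{t}(1-\beta_s)$ in the code's 0-based indexing; $c$ stands for `cond` and $\tilde{x}_L$ for `img_lr_up`:

$$
\begin{aligned}
\mathrm{q\_sample}(x_0, t, \epsilon) &=
\begin{cases}
\sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon, & t \ge 0,\\
x_0, & t < 0,
\end{cases}
\qquad \epsilon \sim \mathcal{N}(0, I),\\
\mathcal{L}_{\ell_1} &= \operatorname{mean}\Bigl|\,\epsilon - \epsilon_\theta\bigl(\mathrm{q\_sample}(x_0, t, \epsilon),\, t,\, c,\, \tilde{x}_L\bigr)\Bigr|.
\end{aligned}
$$

Read this way, `x_tp1_gt` and `x_t_gt` are the same clean sample noised to consecutive steps $t$ and $t-1$ with the same $\epsilon$; when $t = 0$, `x_t_gt` falls back to $x_0$ through `t_cond`. Only `noise_pred` enters the loss; `x_t_pred`, `x_t_gt` and `x0_pred` are just returned, presumably for logging or visualization.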

My understanding is that these functions are also what produce the required x_t input, but it's not clear to me either how or why, given that the paper's description of the training phase mentions only x_r and x_e as inputs, together with the LR-HR pair.
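To connect the code to the paper's training description, here is a minimal sketch of one training step. It assumes, following the paper, that `x_start` is the residual x_r (HR image minus upsampled LR image, taken here as a plain difference; the repository may also rescale it) and that `cond` is the LR encoder output x_e. Under that reading, x_t is not a separate input loaded from the dataset: it is generated on the fly by drawing a random step t and Gaussian noise and calling `q_sample`. NumPy stands in for PyTorch, and `make_schedule`, `training_step` and the linear beta schedule are hypothetical illustrations, not the repository's code.

```python
import numpy as np


def make_schedule(num_steps=100, beta_start=1e-4, beta_end=0.02):
    """Hypothetical linear beta schedule; SRDiff's actual schedule may differ."""
    betas = np.linspace(beta_start, beta_end, num_steps)
    alphas_cumprod = np.cumprod(1.0 - betas)
    # Same role as self.sqrt_alphas_cumprod / self.sqrt_one_minus_alphas_cumprod
    return np.sqrt(alphas_cumprod), np.sqrt(1.0 - alphas_cumprod)


def q_sample(x_start, t, noise, sqrt_ac, sqrt_om_ac):
    """Closed-form forward diffusion; entries with t < 0 return x_start unchanged."""
    t_cond = (t >= 0).astype(x_start.dtype)[:, None, None, None]
    t = np.clip(t, 0, None)
    a = sqrt_ac[t][:, None, None, None]     # sqrt(alpha_bar_t), broadcast over C, H, W
    b = sqrt_om_ac[t][:, None, None, None]  # sqrt(1 - alpha_bar_t)
    return (a * x_start + b * noise) * t_cond + x_start * (1 - t_cond)


def training_step(img_hr, img_lr_up, cond, denoise_fn, rng, sqrt_ac, sqrt_om_ac):
    """One noise-prediction step; returns the L1 loss and the generated x_t."""
    # Assumption: x_r is the plain residual between HR and upsampled LR images
    x_start = img_hr - img_lr_up
    batch = x_start.shape[0]
    # A random diffusion step per sample and fresh Gaussian noise
    t = rng.integers(0, len(sqrt_ac), size=batch)
    noise = rng.standard_normal(x_start.shape)
    # This is where x_t comes from: the residual noised to step t
    x_t = q_sample(x_start, t, noise, sqrt_ac, sqrt_om_ac)
    # The network sees x_t, t, the encoder features (x_e) and the upsampled LR image
    noise_pred = denoise_fn(x_t, t, cond, img_lr_up)
    loss = np.abs(noise - noise_pred).mean()  # mirrors the 'l1' branch of p_losses
    return loss, x_t
```

In a real setup `denoise_fn` would be the conditional U-Net and the loss would be backpropagated; the sketch only shows the data flow, i.e. that the pair (x_r, x_e) plus sampled (t, ε) is enough to build every input the denoiser needs.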

Thank you in advance for your attention!

Frankie91 commented 1 year ago

Found the answer elsewhere and managed to implement the model successfully.

lanhoulllllllllll commented 1 year ago

@Frankie91 Hello, could I get in touch with you? I have some questions about the code that I'd like to ask you.