DeminYu98 / DiffCast

[CVPR 2024] Official implementation of "DiffCast: A Unified Framework via Residual Diffusion for Precipitation Nowcasting"
GNU General Public License v3.0

About the loss of image quality after adding diffusion #4

Status: Open. Spring-lovely opened this issue 4 months ago

Spring-lovely commented 4 months ago

Hi, dear author, thank you so much for your open-source work. I have the following questions about running the code and hope you can take some time to answer them. My setup predicts 6 frames from 6 input frames.

[Image: comparison grid of input frames, ground truth, final prediction, and backbone output]

In the picture, the first row is the input frames, the second row is the ground truth, the third row is the final prediction, and the fourth row is the backbone output. After adding the diffusion part, the image quality drops significantly (see row 3). I am not sure whether there is a problem with the denoising part; could you please take a look?

I only added the following functions to diffcast.py:

```python
def predict(self, frames_in, compute_loss=False, frames_gt=None, **kwargs):
    T_out = default(kwargs.get('T_out'), 6)

    if compute_loss:
        B, T_in, c, h, w = frames_in.shape
        device = self.device

        backbone_output, backbone_loss = self.backbone_net.predict(frames_in, frames_gt=frames_gt,
                                                                   compute_loss=compute_loss, **kwargs)

        residual = frames_gt - backbone_output
        global_ctx, local_ctx = self.ctx_net.scan_ctx(torch.cat((frames_in, backbone_output), dim=1))

        pre_frag = frames_in
        pre_mu = None
        pred_ress = []
        diff_loss = 0.
        t = torch.randint(0, self.num_timesteps, (B,), device=self.device).long()
        for frag_idx in range(T_out // T_in):
            mu = backbone_output[:, frag_idx * T_in : (frag_idx + 1) * T_in]
            res = residual[:, frag_idx * T_in : (frag_idx + 1) * T_in]

            cond = pre_frag - pre_mu if pre_mu is not None else torch.zeros_like(pre_frag)
            res_pred, noise_loss = self.p_losses(res, t, cond=cond, ctx=global_ctx if frag_idx > 0 else local_ctx,
                                                 idx=torch.full((B,), frag_idx, device=device, dtype=torch.long))
            diff_loss += noise_loss

            frag_pred = res_pred + mu
            pre_frag = frag_pred
            pre_mu = mu

        alpha = torch.tensor(0.5)
        loss = (1 - alpha) * backbone_loss + alpha * diff_loss / 3.

        #backbone_output = self.unnormalize(backbone_output)
        return backbone_output, loss
    else:
        pred, mu, y = self.sample(frames_in=frames_in, T_out=T_out)
        loss = None
        backbone_loss = None
        diff_loss = None

        # return pred, mu, y, loss
        return pred, mu

def p_losses(self, x_start, t, noise=None, offset_noise_strength=None, ctx=None, idx=None, cond=None):
    b, _, c, h, w = x_start.shape

    noise = default(noise, lambda: torch.randn_like(x_start))

    # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
    offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)

    if offset_noise_strength > 0.:
        offset_noise = torch.randn(x_start.shape[:2], device = self.device)
        noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')

    # noise sample
    x = self.predict_v(x_start=x_start, t=t, noise=noise)
    model_out = self.model(x, t, cond=cond, ctx=ctx, idx=idx)

    if self.objective == 'pred_noise':
        target = noise
    elif self.objective == 'pred_x0':
        target = x_start
    elif self.objective == 'pred_v':
        v = self.predict_v(x_start, t, noise)
        target = v
    else:
        raise ValueError(f'unknown objective {self.objective}')

    loss = F.mse_loss(model_out, target, reduction = 'none')
    loss = reduce(loss, 'b ... -> b', 'mean')

    loss = loss * extract(self.loss_weight, t, loss.shape)
    return model_out, loss.mean()
```

DeminYu98 commented 4 months ago

Hi, thanks for following our work. From your code, I notice that you use model_out as residual_pred, which is an error under the diffusion training strategy. You can find a reference implementation of the p_losses function at Line 767 of the linked file. During training, a diffusion model typically does not generate the target directly. Finally, you can check whether the diffusion model is being trained properly by watching the trend of diff_loss during training.
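To make the point concrete: under the 'pred_noise' objective, the network receives the noised sample x_t and its output is compared with the injected noise, so `model_out` is a noise estimate rather than a residual estimate. Below is a minimal NumPy sketch, assuming a linear beta schedule; the schedule and function names are illustrative and not DiffCast's code. The oracle predictor cheats by knowing the clean residual and exists only to show what a perfect output looks like.

```python
import numpy as np

def make_schedule(num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Linear beta schedule; returns sqrt(alpha_bar_t) and sqrt(1 - alpha_bar_t) for every t."""
    betas = np.linspace(beta_start, beta_end, num_timesteps)
    alphas_cumprod = np.cumprod(1.0 - betas)
    return np.sqrt(alphas_cumprod), np.sqrt(1.0 - alphas_cumprod)

def q_sample(x_start, t, noise, sqrt_ab, sqrt_1m_ab):
    """Forward process: x_t = sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise."""
    shape = (-1,) + (1,) * (x_start.ndim - 1)   # one coefficient per batch element
    return sqrt_ab[t].reshape(shape) * x_start + sqrt_1m_ab[t].reshape(shape) * noise

def pred_noise_loss(model, x_start, t, rng, sqrt_ab, sqrt_1m_ab):
    """'pred_noise' objective: the network sees x_t and is scored against the noise."""
    noise = rng.standard_normal(x_start.shape)
    x_t = q_sample(x_start, t, noise, sqrt_ab, sqrt_1m_ab)
    model_out = model(x_t, t)                    # an estimate of `noise`, not of x_start
    return model_out, float(np.mean((model_out - noise) ** 2))

# Usage with a stand-in residual of shape (B, T, C, H, W) and an oracle noise predictor
rng = np.random.default_rng(0)
sqrt_ab, sqrt_1m_ab = make_schedule()
residual = rng.standard_normal((2, 6, 1, 8, 8))
t = np.array([10, 500])
bshape = (-1, 1, 1, 1, 1)
oracle = lambda x_t, t: (x_t - sqrt_ab[t].reshape(bshape) * residual) / sqrt_1m_ab[t].reshape(bshape)
model_out, loss = pred_noise_loss(oracle, residual, t, rng, sqrt_ab, sqrt_1m_ab)
# loss is ~0, yet model_out (the noise) is nothing like the residual
```

A falling diff_loss therefore means the noise is being predicted well; it says nothing about `model_out` resembling the residual.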

Yager-42 commented 3 months ago


@DeminYu98, but from the code @Spring-lovely posted, the diffusion model's output seems to be residual_pred, since the target in p_losses is the residual itself, and it looks the same in the link you posted. If model_out is not residual_pred, then what is it? If you could give a more specific case or point out which line should be modified to make it right, I would be very grateful. Thanks for your time.

DeminYu98 commented 3 months ago

@Yager-42 Well, thanks for pointing out your confusion. I think it would be better to consider this question from the basic theory of DDPM.
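The DDPM relation alluded to here: with the 'pred_noise' objective, the network output is an estimate eps_hat of the injected noise, and a residual estimate is only obtained by inverting the forward process, x0_hat = (x_t - sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(alpha_bar_t). Only with 'pred_x0' is the raw output itself a residual estimate. A minimal NumPy sketch of that inversion, assuming a linear beta schedule; it mirrors the relation behind `predict_start_from_noise` in lucidrains' denoising-diffusion-pytorch, on which this code appears to be based, but is not DiffCast's implementation.

```python
import numpy as np

def predict_start_from_noise(x_t, t, eps, alphas_cumprod):
    """Invert x_t = sqrt(ab_t) * x0 + sqrt(1 - ab_t) * eps for x0, given an estimate of eps."""
    shape = (-1,) + (1,) * (x_t.ndim - 1)
    ab = alphas_cumprod[t].reshape(shape)
    return (x_t - np.sqrt(1.0 - ab) * eps) / np.sqrt(ab)

betas = np.linspace(1e-4, 0.02, 1000)          # assumed linear schedule
alphas_cumprod = np.cumprod(1.0 - betas)

rng = np.random.default_rng(1)
x0 = rng.standard_normal((2, 6, 1, 8, 8))      # e.g. a residual segment (B, T, C, H, W)
eps = rng.standard_normal(x0.shape)
t = np.array([0, 999])
ab = alphas_cumprod[t].reshape(-1, 1, 1, 1, 1)
x_t = np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * eps

# With a perfect noise estimate, the residual is recovered exactly;
# the noise estimate itself is not the residual.
x0_hat = predict_start_from_noise(x_t, t, eps, alphas_cumprod)
```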

Yager-42 commented 3 months ago

@DeminYu98 Thanks for your reply, I understand now. Best wishes to you.

Spring-lovely commented 2 months ago

Hello, sorry to disturb you again. Following your last guidance, when computing the loss in the training phase I set cond = frame_in - gt and have the UNet predict x0, and I get the prediction shown in the third row. My final result looks more similar to the input (first row) than to the ground truth (second row). Is there something wrong with my code? Although the result does look more refined than the backbone output, metrics such as CSI and HSS decreased after adding diffusion compared to the backbone alone. Where might the problem be? Thank you, and I look forward to your reply.

[Image: prediction results]
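Since the comparison here is in terms of CSI and HSS, it helps to recall that both are computed from a binary contingency table at an intensity threshold, so they reward getting pixels on the right side of the threshold, not visual sharpness. A minimal NumPy sketch with the standard formulas, assuming `>=` binarization at a single threshold; the thresholds, pixel scaling, and any pooling in DiffCast's own evaluation are not reproduced here.

```python
import numpy as np

def contingency(pred, obs, thr):
    """Hits, false alarms, misses, correct negatives after binarizing both fields at `thr`."""
    p, o = pred >= thr, obs >= thr
    return (int(np.sum(p & o)), int(np.sum(p & ~o)),
            int(np.sum(~p & o)), int(np.sum(~p & ~o)))

def csi(pred, obs, thr):
    """Critical Success Index: hits / (hits + misses + false alarms)."""
    a, b, c, _ = contingency(pred, obs, thr)
    return a / (a + b + c) if (a + b + c) else float("nan")

def hss(pred, obs, thr):
    """Heidke Skill Score: 2(ad - bc) / ((a + c)(c + d) + (a + b)(b + d))."""
    a, b, c, d = contingency(pred, obs, thr)
    denom = (a + c) * (c + d) + (a + b) * (b + d)
    return 2 * (a * d - b * c) / denom if denom else float("nan")

pred = np.array([1.0, 1.0, 0.0, 0.0, 1.0])
obs = np.array([1.0, 0.0, 1.0, 0.0, 1.0])
csi_val = csi(pred, obs, 0.5)   # a=2, b=1, c=1, d=1 -> 2 / 4 = 0.5
hss_val = hss(pred, obs, 0.5)   # 2 * (2 - 1) / (6 + 6) = 1/6
```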

DeminYu98 commented 1 month ago

@Spring-lovely I apologize for the delayed response. From a visual standpoint, your prediction results appear normal. To make a more comprehensive assessment, it would help to examine the training loss and the residual visualizations. Since I'm not fully aware of your specific training process, I can't say definitively whether there are any issues. Regarding the metrics you mentioned, I noticed that you're only predicting 6 frames. For such short-term predictions, deterministic backbones do tend to perform better. Additionally, I'm curious whether your 6-frame prediction is produced autoregressively or as a single segment. The segment size can significantly affect DiffCast's performance, as we noted in our paper's appendix. If you have any further questions or need clarification on any point, please don't hesitate to ask.
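The autoregressive-versus-single-segment question maps directly onto the `for frag_idx in range(T_out // T_in)` loop in the loss code: each iteration is one diffusion pass over a segment of T_in frames, conditioned on the previous segment. A small Python sketch of that indexing (the helper name `segment_slices` is hypothetical). With 6 input and 6 output frames there is exactly one segment, so the previous-segment conditioning never comes into play.

```python
def segment_slices(T_in, T_out):
    """(start, end) frame ranges of the output segments, one per pass of
    `for frag_idx in range(T_out // T_in)`; frames past the last full segment are skipped."""
    return [(k * T_in, (k + 1) * T_in) for k in range(T_out // T_in)]

single = segment_slices(6, 6)    # [(0, 6)]: one segment, nothing autoregressive
multi = segment_slices(5, 20)    # [(0, 5), (5, 10), (10, 15), (15, 20)]: hypothetical 5-in-20-out
partial = segment_slices(6, 8)   # [(0, 6)]: frames 6 and 7 never get a diffusion pass
```

The last case is worth checking in any custom setup: because of the integer division, an output length that is not a multiple of T_in silently drops the trailing frames from the diffusion loss.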

sqfoo commented 1 week ago

Thanks, @Spring-lovely, for providing the loss function. However, I have a few suggestions:

  1. We need to set the `auto_normalize` parameter of the `GaussianDiffusion` class to False for the loss function to work correctly.
  2. `x` in `p_losses` in the code above should be computed as `x = self.q_sample(x_start=x_start, t=t, noise=noise)`, as suggested in previous discussions.
  3. @Spring-lovely's task seems to be a 1-in-1-out task, whereas the paper describes an autoregressive 1-in-4-out task. The loss function above therefore does not fit the autoregressive setting, which can lead to discontinuities between segments. I recommend using: `cond = [:, (frag_idx-1) * T_in : (frag_idx) * T_in] if pre_mu is not None else torch.zeros_like(pre_frag)`.
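The difference behind point 2: `q_sample` produces the noised input x_t that the network should see, while `predict_v` produces the v-parameterization target. A minimal NumPy sketch of both formulas, written after lucidrains' definitions with a scalar alpha_bar for a single timestep (illustrative, not DiffCast's code). At early timesteps (alpha_bar near 1), x_t stays close to x0 while v is close to pure noise, so feeding v to the network as its input gives it little information about the residual.

```python
import numpy as np

def q_sample(x0, eps, ab):
    """Noised input x_t = sqrt(ab) * x0 + sqrt(1 - ab) * eps (what the network should see)."""
    return np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * eps

def predict_v(x0, eps, ab):
    """v-parameterization target v = sqrt(ab) * eps - sqrt(1 - ab) * x0 (a target, not an input)."""
    return np.sqrt(ab) * eps - np.sqrt(1.0 - ab) * x0

rng = np.random.default_rng(0)
x0 = rng.standard_normal((4, 4))
eps = rng.standard_normal((4, 4))
ab = 0.9  # alpha_bar at some timestep

x_t = q_sample(x0, eps, ab)
v = predict_v(x0, eps, ab)
# The two tensors differ, but together they recover x0:
# sqrt(ab) * x_t - sqrt(1 - ab) * v = ab * x0 + (1 - ab) * x0 = x0
x0_rec = np.sqrt(ab) * x_t - np.sqrt(1.0 - ab) * v
```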

I've attached the final loss function below. I would appreciate any feedback or suggestions from anyone.

```python
# Assumed imports at the top of diffcast.py (`default` and `extract` are helper functions defined there)
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from einops import rearrange, reduce

def compute_loss(self, frames_in, frames_gt):
    compute_loss = True
    B, T_in, c, h, w = frames_in.shape
    T_out = frames_gt.shape[1]
    device = frames_in.device

    # Deterministic backbone forecast (mu) and its own training loss
    backbone_output, backbone_loss = self.backbone_net.predict(frames_in, frames_gt=frames_gt, compute_loss=compute_loss)

    frames_in = self.normalize(frames_in)
    backbone_output = self.normalize(backbone_output)
    frames_gt = self.normalize(frames_gt)

    # The diffusion model learns the residual between the ground truth and the backbone forecast
    residual = frames_gt - backbone_output
    global_ctx, local_ctx = self.ctx_net.scan_ctx(torch.cat((frames_in, backbone_output), dim=1))

    pre_frag = frames_in
    pre_mu = None
    diff_loss = 0.
    t = torch.randint(0, self.num_timesteps, (B,), device=device).long()
    for frag_idx in range(T_out // T_in):
        mu = backbone_output[:, frag_idx * T_in : (frag_idx + 1) * T_in]
        res = residual[:, frag_idx * T_in : (frag_idx + 1) * T_in]

        # First segment: zero condition; later segments: residual of the previous segment
        cond = pre_frag - pre_mu if pre_mu is not None else torch.zeros_like(pre_frag)
        # The first return value is the raw network output (noise / x0 / v estimate), not the residual
        _, noise_loss = self.p_losses(res, t, cond=cond, ctx=global_ctx if frag_idx > 0 else local_ctx,
                                      idx=torch.full((B,), frag_idx, device=device, dtype=torch.long))
        diff_loss += noise_loss

        # Teacher forcing: condition the next segment on the ground truth, not on a prediction
        pre_frag = frames_gt[:, frag_idx * T_in : (frag_idx + 1) * T_in]
        pre_mu = mu
    diff_loss /= (T_out // T_in)

    alpha = torch.tensor(0.5)
    loss = (1 - alpha) * backbone_loss + alpha * diff_loss
    return loss

# Reference: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L763
@autocast(enabled = False)
def q_sample(self, x_start, t, noise = None):
    noise = default(noise, lambda: torch.randn_like(x_start))

    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    return (
        extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
        extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
    )

def p_losses(self, x_start, t, noise=None, offset_noise_strength=None, ctx=None, idx=None, cond=None):
    b, _, c, h, w = x_start.shape

    noise = default(noise, lambda: torch.randn_like(x_start))

    # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
    offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)

    if offset_noise_strength > 0.:
        # x_start is (b, t, c, h, w): one offset per (sample, frame), broadcast over c, h, w
        offset_noise = torch.randn(x_start.shape[:2], device = self.device)
        noise = noise + offset_noise_strength * rearrange(offset_noise, 'b t -> b t 1 1 1')

    # Forward diffusion: noise x_start to timestep t (q_sample, not predict_v)
    x = self.q_sample(x_start=x_start, t=t, noise=noise)

    model_out = self.model(x, t, cond=cond, ctx=ctx, idx=idx)

    if self.objective == 'pred_noise':
        target = noise
    elif self.objective == 'pred_x0':
        target = x_start
    elif self.objective == 'pred_v':
        v = self.predict_v(x_start, t, noise)
        target = v
    else:
        raise ValueError(f'unknown objective {self.objective}')

    loss = F.mse_loss(model_out, target, reduction = 'none')
    loss = reduce(loss, 'b ... -> b', 'mean')

    # Per-timestep loss weighting
    loss = loss * extract(self.loss_weight, t, loss.shape)
    return model_out, loss.mean()
```
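One detail worth checking in this code is tensor rank: the frames here are 5D (batch, time, channel, height, width), whereas lucidrains' helpers were written for 4D images. The `extract` helper picks one schedule coefficient per batch element and reshapes it to (B, 1, ..., 1) so it broadcasts over the whole sample. Likewise, offset noise drawn with shape `x_start.shape[:2]` must become (B, T, 1, 1, 1) before it is added to a 5D tensor, since a 4D reshape such as (B, T, 1, 1) does not broadcast. A NumPy re-sketch of both, assuming the same `extract` semantics (index by `t`, then reshape); this only matters when `offset_noise_strength > 0`.

```python
import numpy as np

def extract(a, t, x_shape):
    """Pick a[t[i]] for each batch element and reshape to (B, 1, ..., 1) so it broadcasts over x."""
    out = a[t]
    return out.reshape((t.shape[0],) + (1,) * (len(x_shape) - 1))

B, T, C, H, W = 2, 6, 1, 8, 8
x = np.ones((B, T, C, H, W))
coeffs = np.linspace(0.0, 1.0, 1000)          # stand-in for e.g. sqrt_alphas_cumprod
t = np.array([0, 999])
scaled = extract(coeffs, t, x.shape) * x      # shape (B, T, C, H, W)

# Per-(sample, frame) offset noise needs three trailing singleton axes for a 5D tensor.
rng = np.random.default_rng(0)
offset = rng.standard_normal((B, T))
noise = rng.standard_normal(x.shape) + 0.1 * offset.reshape(B, T, 1, 1, 1)

try:
    _ = rng.standard_normal(x.shape) + offset.reshape(B, T, 1, 1)
    broadcast_ok = True
except ValueError:
    broadcast_ok = False   # (B, T, 1, 1) does not line up with (B, T, C, H, W)
```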