ge-xing / Diff-UNet

Diff-UNet: A Diffusion Embedded Network for Volumetric Segmentation. (using diffusion for 3D medical image segmentation)
Apache License 2.0

About Train #19

Closed. jyx1244204317 closed this issue 1 year ago.

jyx1244204317 commented 1 year ago
def training_step(self, batch):
    image, label = self.get_input(batch)
    x_start = label

    # Rescale the binary label from [0, 1] to [-1, 1] before diffusing it.
    x_start = x_start * 2 - 1
    # Forward process: noise the label to a random timestep t.
    x_t, t, noise = self.model(x=x_start, pred_type="q_sample")
    # Reverse process: predict the clean mask x_0, conditioned on the image.
    pred_xstart, pred_y = self.model(x=x_t, step=t, image=image, pred_type="denoise")

    # Supervision is applied directly against the ground-truth label.
    loss_dice = self.dice_loss(pred_xstart, label)
    loss_bce = self.bce(pred_xstart, label)

    pred_xstart = torch.sigmoid(pred_xstart)
    loss_mse = self.mse(pred_xstart, label)

    loss = loss_dice + loss_bce + loss_mse

    return loss

In the training phase, noise is added via q_sample, and the loss for the prediction is then computed against the label. In my understanding, the prediction from x_t should be compared against x_{t-1}, and x_0 should only be produced at the end as the final segmentation target. Why is the loss not computed against the next denoising step, but directly against the ground-truth label?
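
For context, here is a minimal sketch of what a DDPM-style q_sample step usually computes; the identifiers (q_sample, alphas_cumprod) are illustrative assumptions, not necessarily the exact names used in this repository:

import torch

def q_sample(x_start, t, alphas_cumprod, noise=None):
    # Forward diffusion: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps
    if noise is None:
        noise = torch.randn_like(x_start)
    # Broadcast the per-sample cumulative alpha over channel/spatial dims (B, C, D, H, W).
    a_bar = alphas_cumprod[t].view(-1, *([1] * (x_start.dim() - 1)))
    x_t = a_bar.sqrt() * x_start + (1.0 - a_bar).sqrt() * noise
    return x_t, noise

In the quoted snippet, the pred_type="q_sample" call appears to wrap this step and additionally returns the sampled timestep t and the noise.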

920232796 commented 1 year ago

For the segmentation task, the diffusion model needs to predict x_0 (the clean mask) directly, rather than the noise, which is why the loss is computed against the ground-truth label.
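
To make the distinction concrete, here is a hedged sketch contrasting the two parameterisations; the model signature and variable names are illustrative assumptions, not the repository's exact API:

import torch
import torch.nn.functional as F

def noise_prediction_loss(model, x_t, t, noise):
    # Standard DDPM objective for generation: the network predicts the added noise.
    pred_noise = model(x_t, t)
    return F.mse_loss(pred_noise, noise)

def x0_prediction_loss(model, x_t, t, image, label):
    # Segmentation-style objective (as in the snippet above): the network predicts
    # the clean mask x_0, so BCE/MSE (and Dice) can be computed against the label directly.
    pred_x0 = model(x_t, t, image)  # mask logits
    loss_bce = F.binary_cross_entropy_with_logits(pred_x0, label)
    loss_mse = F.mse_loss(torch.sigmoid(pred_x0), label)
    return loss_bce + loss_mse

Because the target is the segmentation mask itself rather than Gaussian noise, the loss can be applied against the label at every timestep instead of against x_{t-1}.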

jyx1244204317 commented 1 year ago

Thank you!