HuiZhang0812 / DiffusionAD


Is the whole network 2-stage training? #12

Closed HSIU-HUA closed 10 months ago

HSIU-HUA commented 10 months ago

I mean, should I train the diffusion model first and then train the diffusion model and segmentation network together (two-stage training)? Or do I just train the diffusion model and segmentation network together from the start (one-stage training)?

NilsB98 commented 10 months ago

Hi @HSIU-HUA, the training is done in a single stage. You can see this in the "Algorithm 1" pseudocode included in the paper, which contains the following:

# Design noise prediction loss
loss_noise = design_noise_loss(eps, pred_eps)

# Predict flawless approximation via predicted eps
pred_images = sqrt(1 / alpha_cumprod(t)) * img_crpt - sqrt(1 / alpha_cumprod(t) - 1) * pred_eps

# Predict anomaly mask
pred_mask = Segmentation_sub_network(cat((images, pred_images), dim=1)) # [B, 1, H, W]

# Design segmentation loss
loss_mask = design_mask_loss(gt_mask, pred_mask)

# Design total loss
loss = loss_noise + loss_mask

Here we can see that both losses are combined, and the resulting total loss is used for a single optimization step. I don't see any mention of first training only the diffusion model, so I don't think there is a separate procedure beyond this one training step.
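As a sanity check on the `pred_images` line above: under the standard DDPM forward process x_t = sqrt(ᾱ_t)·x_0 + sqrt(1 − ᾱ_t)·ε, the expression sqrt(1/ᾱ_t)·x_t − sqrt(1/ᾱ_t − 1)·ε recovers x_0 exactly when the noise prediction is perfect. A minimal NumPy sketch (variable names are my own, not from the repo):

```python
import numpy as np

rng = np.random.default_rng(0)

x0 = rng.standard_normal((4, 4))   # clean "image"
eps = rng.standard_normal((4, 4))  # true noise
alpha_cumprod_t = 0.37             # example value of ᾱ_t

# Forward diffusion: x_t = sqrt(ᾱ_t) * x0 + sqrt(1 - ᾱ_t) * eps
x_t = np.sqrt(alpha_cumprod_t) * x0 + np.sqrt(1 - alpha_cumprod_t) * eps

# One-step x0 estimate from the noise (here: the exact noise),
# matching the pseudocode's pred_images expression
x0_hat = np.sqrt(1 / alpha_cumprod_t) * x_t \
         - np.sqrt(1 / alpha_cumprod_t - 1) * eps

assert np.allclose(x0_hat, x0)  # identity holds when eps is exact
```

In training, `pred_eps` from the denoising network replaces `eps`, so `pred_images` is only an approximation of the flawless image, which is exactly what the segmentation sub-network then compares against the input.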

Best, Nils

HSIU-HUA commented 10 months ago

@NilsB98 Hi, thank you for your reply. Have you reimplemented this paper successfully?

Sincerely, HSIU-HUA

NilsB98 commented 10 months ago

Not yet, but I'll try it in the coming days. I'll leave you a comment here if I'm successful :)

HSIU-HUA commented 10 months ago

@NilsB98 Thank you so much! I also tried to reimplement this, but I couldn't get results as good as those described in the paper. Please share your results whether you succeed or fail, so we can compare notes. Thank you in advance.

HuiZhang0812 commented 10 months ago

Thank you for your patience. DiffusionAD is an end-to-end architecture. We have open-sourced the code.