pfriedri / wdm-3d

PyTorch implementation for "WDM: 3D Wavelet Diffusion Models for High-Resolution Medical Image Synthesis" (DGM4MICCAI 2024)
https://pfriedri.github.io/wdm-3d-io
MIT License

Loss Problem #4

Closed by kanydao 3 months ago

kanydao commented 3 months ago

Thank you for your excellent work, but I have some questions regarding the calculation of the model loss.

terms = {"mse_wav": th.mean(mean_flat((x_start_dwt - model_output) ** 2), dim=0)}

For the DDPM diffusion model, the loss is typically between the noise and the model output. Is there a mathematical basis for directly modifying the loss function this way, and does it compromise the optimization objective of the diffusion model?

pfriedri commented 3 months ago

@kanydao Thank you for the interest in our work.

There are different formulations of the DDPM. What we want to do is parameterize $q(x_{t-1}|x_{t},x_{0})=\mathcal{N}(\tilde{\mu}_t, \tilde{\beta}_t I)$. There are, however, different ways of modeling $\tilde{\mu}_t$; the most obvious one would probably be to predict it directly with a neural network. Alternatively, the network could predict the noise $\epsilon(x_t,t)$ and define $\tilde{\mu}_t(x_t,t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon(x_t,t)\right)$.

We instead let the network predict $\tilde{x}_0$ and define $\tilde{\mu}_t(x_t, \tilde{x}_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}\tilde{x}_0 + \frac{\sqrt{\alpha_t} (1 - \bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}x_t$. This is mathematically equivalent; we just found that predicting $x_0$ worked better (in this setting) than predicting $\epsilon$.
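The equivalence of the two expressions for $\tilde{\mu}_t$ above can be checked numerically. The following is a minimal sketch, not code from the repository: the function names, the linear beta schedule, and its values are illustrative assumptions. It relies on the forward-process relation $x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\epsilon$ and the convention $\bar{\alpha}_0 = 1$ for the step before the first one.

```python
import math
import random


def linear_beta_schedule(num_steps, beta_start=1e-4, beta_end=0.02):
    """Illustrative linear beta schedule (assumed values, not the repo's config)."""
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def cumulative_alphas(betas):
    """Return alpha_bar_t, the running product of (1 - beta_s), for every step."""
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def mean_from_eps(x_t, eps, beta_t, alpha_bar_t):
    """Posterior mean written in terms of the noise epsilon."""
    alpha_t = 1.0 - beta_t
    return (x_t - beta_t / math.sqrt(1.0 - alpha_bar_t) * eps) / math.sqrt(alpha_t)


def mean_from_x0(x_t, x0, beta_t, alpha_bar_t, alpha_bar_prev):
    """Posterior mean written in terms of x_0 (alpha_bar_prev is alpha_bar_{t-1})."""
    alpha_t = 1.0 - beta_t
    coef_x0 = math.sqrt(alpha_bar_prev) * beta_t / (1.0 - alpha_bar_t)
    coef_xt = math.sqrt(alpha_t) * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t)
    return coef_x0 * x0 + coef_xt * x_t


def max_gap(num_steps=1000, seed=0):
    """Largest absolute difference between the two forms over all timesteps."""
    rng = random.Random(seed)
    betas = linear_beta_schedule(num_steps)
    alpha_bars = cumulative_alphas(betas)
    worst = 0.0
    for t in range(num_steps):
        alpha_bar_t = alpha_bars[t]
        alpha_bar_prev = alpha_bars[t - 1] if t > 0 else 1.0
        # Draw a scalar "image" and noise, then form x_t with the forward process.
        x0, eps = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        x_t = math.sqrt(alpha_bar_t) * x0 + math.sqrt(1.0 - alpha_bar_t) * eps
        gap = abs(
            mean_from_eps(x_t, eps, betas[t], alpha_bar_t)
            - mean_from_x0(x_t, x0, betas[t], alpha_bar_t, alpha_bar_prev)
        )
        worst = max(worst, gap)
    return worst


if __name__ == "__main__":
    # The two forms should agree up to floating-point rounding.
    print(max_gap() < 1e-8)
```

Because $x_0$ and $\epsilon$ are linked through $x_t$ in this way, a squared error on the predicted $x_0$ equals the squared error on the implied $\epsilon$ scaled by $(1-\bar{\alpha}_t)/\bar{\alpha}_t$. An MSE on $x_0$ (here on its wavelet coefficients) therefore targets the same posterior mean with a different per-timestep weighting.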

You can find more information in chapter 2 of: Nichol, A. Q., & Dhariwal, P. (2021). Improved denoising diffusion probabilistic models. In International conference on machine learning (pp. 8162-8171). PMLR. Or in chapters 2 and 3 of: Ho, J., Jain, A., & Abbeel, P. (2020). Denoising diffusion probabilistic models. Advances in neural information processing systems, 33, 6840-6851.

I hope this answers your question.

kanydao commented 3 months ago

Thank you very much for your help. I have carefully read through the implementation of `p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None)` and I understand your approach. Excellent work!