claroche-r / FastDiffusionEM


(FastEM/EM)-PiGDM with general degradations #4

Closed man-sean closed 6 months ago

man-sean commented 6 months ago

Thank you for implementing PiGDM; it is the only resource I know of that implements the noisy version!

As far as I can tell, the code only supports deblurring tasks: https://github.com/claroche-r/FastDiffusionEM/blob/93abbfba1f41767952239a978be169b59f5da355/guided_diffusion/gaussian_diffusion.py#L634

Can you help me understand how we can generalize it to support the different operators supported by the original DPS code?

claroche-r commented 6 months ago

Hi man-sean,

You only have to compute the guidance associated with your inverse problem. In PiGDM, the guidance term is expressed as follows (Equation (7) of their paper): $$\left((y - H\widehat{x}_0)^T(r_t^2 H H^T + \sigma^2 I)^{-1}H \left(\frac{\partial\widehat{x}_0}{\partial x_t}\right) \right)^T$$ So `agem.deblurring_guidance` corresponds to this quantity: $$H^T (r_t^2 H H^T + \sigma^2 I)^{-1}(y - H\widehat{x}_0).$$

For deconvolution, we used the fact that convolution is diagonal in the Fourier domain to compute this quantity. For many other degradation operators, you should be able to compute it without problems. Super-resolution is trickier, but you can use this code as a starting point: https://github.com/yuanzhi-zhu/DiffPIR/blob/592826b9db9075763e2ce70d085b14638fffd890/utils/utils_sisr.py#L65-L75
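To illustrate the Fourier argument, here is a minimal NumPy sketch of the quantity $H^T (r_t^2 H H^T + \sigma^2 I)^{-1}(y - H\widehat{x}_0)$. It assumes $H$ is a circular (periodic) convolution with a real kernel. The name `fourier_guidance` and its signature are hypothetical and are not the repository's `agem.deblurring_guidance`.

```python
import numpy as np

def fourier_guidance(y, x0_hat, kernel, r_t, sigma):
    """Sketch: H^T (r_t^2 H H^T + sigma^2 I)^{-1} (y - H x0_hat).

    Assumes H is a 2D circular convolution with a real `kernel`,
    so H is diagonalized by the 2D DFT (hypothetical helper).
    """
    # Eigenvalues of H: DFT of the kernel zero-padded to the image size.
    # The kernel origin is taken at index (0, 0), i.e. it is not centered.
    k_hat = np.fft.fft2(kernel, s=y.shape)
    # Residual y - H x0_hat, expressed in the Fourier domain.
    residual_hat = np.fft.fft2(y) - k_hat * np.fft.fft2(x0_hat)
    # H^T acts as conj(k_hat); the matrix inverse becomes an elementwise division.
    out_hat = np.conj(k_hat) * residual_hat / (r_t**2 * np.abs(k_hat)**2 + sigma**2)
    return np.real(np.fft.ifft2(out_hat))
```

Other boundary conditions (for example zero padding) are not diagonal in the DFT basis, so this sketch would only be an approximation there.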

Hope this helps.

Best, Charles Laroche

man-sean commented 6 months ago

Thank you for the detailed answer!

I chose to use the degradation modules defined in DDRM: https://github.com/bahjat-kawar/ddrm/blob/master/functions/svd_replacement.py

They are implemented using an (efficient) SVD, which lets us compactly compute the above quantity. With $H = U \Sigma V^\top$: $$V \Sigma \left( r_t^2 \Sigma^2 + \sigma_y^2 I \right)^{-1} U^\top (y - H\widehat{x}_0)$$
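As a rough illustration of the SVD route: with $H = U \Sigma V^\top$, the operator $H^\top (r_t^2 H H^\top + \sigma_y^2 I)^{-1}$ reduces to $V \Sigma (r_t^2 \Sigma^2 + \sigma_y^2 I)^{-1} U^\top$, so the inverse becomes an elementwise division on the singular values. The sketch below uses a dense NumPy SVD only for clarity. It does not use the DDRM degradation classes, and `svd_guidance` is a hypothetical name.

```python
import numpy as np

def svd_guidance(U, s, Vt, y, x0_hat, r_t, sigma_y):
    """Sketch: V S (r_t^2 S^2 + sigma_y^2 I)^{-1} U^T (y - H x0_hat).

    Assumes a thin SVD H = U @ diag(s) @ Vt and sigma_y > 0
    (hypothetical helper, not the DDRM API).
    """
    # Residual y - H x0_hat, with H applied through its SVD factors.
    residual = y - U @ (s * (Vt @ x0_hat))
    # Diagonal of S (r_t^2 S^2 + sigma_y^2 I)^{-1}.
    coeffs = s / (r_t**2 * s**2 + sigma_y**2)
    return Vt.T @ (coeffs * (U.T @ residual))

# Usage with a small random operator (flattened signals).
rng = np.random.default_rng(0)
H = rng.standard_normal((3, 5))
U, s, Vt = np.linalg.svd(H, full_matrices=False)
g = svd_guidance(U, s, Vt, rng.standard_normal(3), rng.standard_normal(5), 0.5, 0.1)
```

In practice, structured operators avoid forming `U` and `Vt` explicitly and apply them as functions instead, which is where the efficiency comes from.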

It works like a charm! Thanks again