Janspiry / Image-Super-Resolution-via-Iterative-Refinement

Unofficial implementation of Image Super-Resolution via Iterative Refinement in PyTorch
Apache License 2.0

ddim sampling #72

Open betterze opened 2 years ago

betterze commented 2 years ago

Dear SR3 team,

Thank you for sharing this great implementation. I tried the Colab for the 64->512 model, and it works very well.

The DDPM sampling with 2k steps is very slow; it takes around 7-8 minutes per image. Would it be possible to use DDIM sampling to increase the sampling speed?

I am new to this field, so thank you for your advice.

Best Wishes,

Zongze

liangbingzhao commented 2 years ago

Have you implemented it? I tried DDIM but found that it fails in SR. I wonder how it could be implemented.

Janspiry commented 2 years ago

Sorry for the late reply. I am only reproducing SR3 and have no plans to use DDIM.

ElliotQi commented 1 year ago

Have you implemented it? I tried DDIM but found that it fails in SR. I wonder how it could be implemented.

@liangbingzhao I've run into the same problem. DDIM failed on this repo. Have you implemented it?

Li-Qingyun commented 1 year ago

@betterze @ElliotQi @liangbingzhao Why did your experiments not work? I tried DDIM too, and the results are as follows:

| diffusion method | SSIM (+) | PSNR (+) |
| -- | -- | -- |
| DDPM | 0.675 | 23.26 |
| DDIM (20 steps) | 0.583 | 22.99 |

Example 1: [images: DDPM result (0_1_sr) and sampling process (0_1_sr_process); DDIM (20 sampling steps) result and process; ground truth (0_1_hr)]

Example 2: [images: DDPM result (0_2_sr) and sampling process (0_2_sr_process); DDIM result and process; ground truth (0_2_hr)]

Li-Qingyun commented 1 year ago

I'm a newcomer; my codebase refers to ddim_sample() in denoising_diffusion_pytorch. Both experiments use the 64×64 -> 512×512 FFHQ-CelebaHQ checkpoint and the sr_sr3_16_128.json config. Are the DDIM results better than the DDPM ones? In any case, my inference speed is much higher (about 90x) on an RTX 3070: DDPM 32.35 s/item vs. DDIM 2.68 items/s (the per-step network evaluation speed is about the same in both cases, roughly 63 steps/s).
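For anyone following along, here is a rough sketch of what such an adaptation might look like. It is not the exact code used above: the method name ddim_super_resolution and its arguments are hypothetical, and it assumes the attributes and helpers of this repo's GaussianDiffusion class (denoise_fn, predict_start_from_noise, alphas_cumprod, sqrt_alphas_cumprod_prev, num_timesteps), mirroring ddim_sample() from denoising-diffusion-pytorch:

    import torch
    from tqdm import tqdm

    @torch.no_grad()
    def ddim_super_resolution(self, x_in, ddim_steps=20, eta=0.0):
        # Sketch of a conditional DDIM loop for an SR3-style GaussianDiffusion class.
        # x_in is the (upsampled) low-resolution conditioning image.
        device = x_in.device
        b = x_in.shape[0]

        # Evenly spaced subsequence of the training timesteps, e.g. 20 out of 2000.
        times = torch.linspace(-1, self.num_timesteps - 1, steps=ddim_steps + 1)
        times = list(reversed(times.int().tolist()))      # [T-1, ..., -1]
        time_pairs = list(zip(times[:-1], times[1:]))     # (t, t_next) pairs

        img = torch.randn_like(x_in)                      # start from pure noise
        for t, t_next in tqdm(time_pairs, desc='ddim sampling'):
            # SR3 conditions the network on a continuous noise level, not the raw timestep.
            noise_level = torch.FloatTensor(
                [self.sqrt_alphas_cumprod_prev[t + 1]]).repeat(b, 1).to(device)
            pred_noise = self.denoise_fn(torch.cat([x_in, img], dim=1), noise_level)
            x_start = self.predict_start_from_noise(img, t=t, noise=pred_noise)
            x_start.clamp_(-1., 1.)

            if t_next < 0:                                # last step: return the x_0 estimate
                img = x_start
                continue

            alpha = self.alphas_cumprod[t]
            alpha_next = self.alphas_cumprod[t_next]      # cumulative alpha at the next *sampled* step
            sigma = eta * ((1 - alpha / alpha_next) *
                           (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()
            noise = torch.randn_like(img)
            img = x_start * alpha_next.sqrt() + c * pred_noise + sigma * noise
        return img

With eta=0 this is the deterministic DDIM sampler; since only the inference rule changes, existing SR3 checkpoints can in principle be reused without retraining.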

liangbingzhao commented 1 year ago

I no longer focus on this problem, but DDIM results are usually a little worse than DDPM ones.

Li-Qingyun commented 1 year ago

I no longer focus on this problem, but DDIM results are usually a little worse than DDPM ones.

OK. The results are indeed a little different, so it seems that my implementation works.

Li-Qingyun commented 1 year ago

@Janspiry Would this repo consider reviewing a PR that adds DDIM support for SR3? I can create a PR, but I would need some support in terms of knowledge, experience, and code review.

wangzhen699 commented 1 year ago

@betterze @ElliotQi @liangbingzhao Why did your experiments not work? I tried DDIM too, and the results are as follows:

| diffusion method | SSIM (+) | PSNR (+) |
| -- | -- | -- |
| DDPM | 0.675 | 23.26 |
| DDIM (20 steps) | 0.583 | 22.99 |

Example 1: [images: DDPM result (0_1_sr) and sampling process (0_1_sr_process); DDIM (20 sampling steps) result and process; ground truth (0_1_hr)]

Example 2: [images: DDPM result (0_2_sr) and sampling process (0_2_sr_process); DDIM result and process; ground truth (0_2_hr)]

Hello, would it be convenient for you to share the DDIM code? I tried it myself, but there are still some bugs. Thank you very much!

Li-Qingyun commented 1 year ago

@wangzhen699 You can refer to denoising-diffusion-pytorch. The code has been used in a colleague's private project.

927514606 commented 1 year ago

@betterze @ElliotQi @liangbingzhao Why did your experiments not work? I tried DDIM too, and the results are as follows:

| diffusion method | SSIM (+) | PSNR (+) |
| -- | -- | -- |
| DDPM | 0.675 | 23.26 |
| DDIM (20 steps) | 0.583 | 22.99 |

Example 1: [images: DDPM result (0_1_sr) and sampling process (0_1_sr_process); DDIM (20 sampling steps) result and process; ground truth (0_1_hr)]

Example 2: [images: DDPM result (0_2_sr) and sampling process (0_2_sr_process); DDIM result and process; ground truth (0_2_hr)]

Hello, would it be convenient for you to share the DDIM code? I tried it myself, but there are still some bugs. Thank you very much!

Could you say where the bugs are?

Li-Qingyun commented 1 year ago

I haven't run into any bugs.

wangzhen699 commented 1 year ago

@wangzhen699 You can refer to denoising-diffusion-pytorch. The code has been used in a colleague's private project.

I haven't run into any bugs.

Hello, thank you very much for your reply. I debugged the code following your suggestions, but it still didn't work.

whiteYi commented 1 year ago

I haven't run into any bugs.

Hello, I'd like to ask: after switching to DDIM, does the model need to be retrained? Even though I only changed the inference procedure, it seems the previous SR3 weights can no longer be used directly.

Li-Qingyun commented 1 year ago

I haven't run into any bugs.

Hello, I'd like to ask: after switching to DDIM, does the model need to be retrained? Even though I only changed the inference procedure, it seems the previous SR3 weights can no longer be used directly.

I don't think so; my understanding is that DDIM is only used at evaluation (inference) time.

whiteYi commented 1 year ago

My understanding is that DDIM changes the inference procedure of DDPM: it modifies p(x_{t-1} | x_t, x_0) and then uses a subsequence of the forward-process timesteps as the actual sampling steps, which shortens inference time. So I have been modifying the diffusion file under model.
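In symbols, the update being described here (standard DDIM, Song et al. 2020, with a chosen timestep subsequence $\{\tau_i\}$) is roughly:

$$
x_{\tau_{i-1}} = \sqrt{\bar\alpha_{\tau_{i-1}}}\,\hat{x}_0
+ \sqrt{1-\bar\alpha_{\tau_{i-1}}-\sigma_{\tau_i}^2}\;\epsilon_\theta(x_{\tau_i},\tau_i)
+ \sigma_{\tau_i}\,\epsilon,
\qquad
\hat{x}_0 = \frac{x_{\tau_i}-\sqrt{1-\bar\alpha_{\tau_i}}\,\epsilon_\theta(x_{\tau_i},\tau_i)}{\sqrt{\bar\alpha_{\tau_i}}}
$$

where $\sigma_{\tau_i} = \eta\sqrt{(1-\bar\alpha_{\tau_{i-1}})/(1-\bar\alpha_{\tau_i})}\,\sqrt{1-\bar\alpha_{\tau_i}/\bar\alpha_{\tau_{i-1}}}$. With $\eta = 0$ the sampler is deterministic, and because only the inference rule changes, weights trained with the DDPM objective can be reused.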

Li-Qingyun commented 1 year ago

My understanding is that DDIM changes the inference procedure of DDPM: it modifies p(x_{t-1} | x_t, x_0) and then uses a subsequence of the forward-process timesteps as the actual sampling steps, which shortens inference time. So I have been modifying the diffusion file under model.

Yes, that's about right.

wangzhen699 commented 1 year ago

My understanding is that DDIM changes the inference procedure of DDPM: it modifies p(x_{t-1} | x_t, x_0) and then uses a subsequence of the forward-process timesteps as the actual sampling steps, which shortens inference time. So I have been modifying the diffusion file under model.

Hello, could you please share your diffusion file? Thanks!

wangzhen699 commented 1 year ago

@betterze @ElliotQi @liangbingzhao Why did your experiments not work? I tried DDIM too, and the results are as follows:

| diffusion method | SSIM (+) | PSNR (+) |
| -- | -- | -- |
| DDPM | 0.675 | 23.26 |
| DDIM (20 steps) | 0.583 | 22.99 |

Example 1: [images: DDPM result (0_1_sr) and sampling process (0_1_sr_process); DDIM (20 sampling steps) result and process; ground truth (0_1_hr)]

Example 2: [images: DDPM result (0_2_sr) and sampling process (0_2_sr_process); DDIM result and process; ground truth (0_2_hr)]

When accelerating the sampling, did you ever run into the case where the final sampled image has no noise left but also shows no super-resolution effect? My settings are as follows:

    # 1. continuous noise level used to condition the network
    noise_level = torch.FloatTensor(
        [self.sqrt_alphas_cumprod_prev[t+1]]).repeat(batch_size, 1).to(x.device)
    # 2. predicted noise
    pred_noise = self.denoise_fn(torch.cat([condition_x, x], dim=1), noise_level)
    # 3. predicted x_0
    x_start = self.predict_start_from_noise(
        x, t=t, noise=self.denoise_fn(torch.cat([condition_x, x], dim=1), noise_level))
    # 4. DDIM update
    alpha = self.alphas_cumprod[time]
    alpha_next = self.alphas_cumprod[time_next]
    sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
    c = (1 - alpha_next - sigma ** 2).sqrt()
    noise = torch.randn_like(img)
    img = x_start * alpha_next.sqrt() + c * pred_noise + sigma * noise

Li-Qingyun commented 1 year ago

@wangzhen699 That didn't happen for me. My implementation is very simple and just worked; I implemented it with reference to the denoising-diffusion-pytorch repo.

kada0720 commented 1 year ago

@betterze @wangzhen699 @whiteYi May I ask whether you have resolved this issue? I also want to use DDIM to accelerate the sampling speed in SR3, but with my current knowledge I don't know how to modify the code. Thank you very much, and I look forward to hearing from you.

@ElliotQi I see that you have successfully used DDIM to accelerate the sampling speed in SR3. May I ask how you modified the code? This has been bothering me for a long time. Thank you very much, and I look forward to hearing from you~

yansonglee commented 1 year ago

@wangzhen699 @ElliotQi @Li-Qingyun With DDIM the sampling results are predominantly noise. Here is my code; how can I modify it to improve the results? Thank you so much. QAQ

x = x_in
shape = x.shape
b = shape[0]
img = torch.randn(shape, device=device)
ret_img = x
for i in tqdm(reversed(range(0, 1000, 5)), desc='sampling loop time step', total=200):
    img = self.ddim_sample(img, torch.full((b,), i, device=device, dtype=torch.long), condition_x=x)

ddim_sample:

def ddim_sample(self, x, t, clip_denoised=True, repeat_noise=False, condition_x=None):
    b, *_, device = *x.shape, x.device
    model_mean, _, model_log_variance, x_start, eps = self.p_mean_variance(
        x=x, t=t, clip_denoised=clip_denoised, condition_x=condition_x)

    eta = 0.0
    alpha_bar = extract(self.alphas_cumprod, t, x.shape)
    alpha_bar_prev = extract(self.alphas_cumprod_prev, t, x.shape)
    sigma = eta * ((1 - alpha_bar / alpha_bar_prev) * (1 - alpha_bar_prev) / (1 - alpha_bar)).sqrt()
    c = (1 - alpha_bar_prev - sigma ** 2).sqrt()
    noise = torch.randn_like(x_start)

    img = x_start * alpha_bar_prev.sqrt() + \
          c * eps + \
          sigma * noise

    return img

Jakkiabc commented 9 months ago

@wangzhen699 @ElliotQi @Li-Qingyun With DDIM the sampling results are predominantly noise. Here is my code; how can I modify it to improve the results? Thank you so much. QAQ

x = x_in
shape = x.shape
b = shape[0]
img = torch.randn(shape, device=device)
ret_img = x
for i in tqdm(reversed(range(0, 1000, 5)), desc='sampling loop time step', total=200):
    img = self.ddim_sample(img, torch.full((b,), i, device=device, dtype=torch.long), condition_x=x)

ddim_sample:

def ddim_sample(self, x, t, clip_denoised=True, repeat_noise=False, condition_x=None):
    b, *_, device = *x.shape, x.device
    model_mean, _, model_log_variance, x_start, eps = self.p_mean_variance(
        x=x, t=t, clip_denoised=clip_denoised, condition_x=condition_x)

    eta = 0.0
    alpha_bar = extract(self.alphas_cumprod, t, x.shape)
    alpha_bar_prev = extract(self.alphas_cumprod_prev, t, x.shape)
    sigma = eta * ((1 - alpha_bar / alpha_bar_prev) * (1 - alpha_bar_prev) / (1 - alpha_bar)).sqrt()
    c = (1 - alpha_bar_prev - sigma ** 2).sqrt()
    noise = torch.randn_like(x_start)

    img = x_start * alpha_bar_prev.sqrt() + \
          c * eps + \
          sigma * noise

    return img

Hi, I am also trying to switch from DDPM to DDIM. Did you succeed?
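One thing that may be worth checking in the loop quoted above (a guess, not verified against this exact code): with a stride of 5, extract(self.alphas_cumprod_prev, t, x.shape) still returns the cumulative alpha at t-1 rather than at the previous sampled timestep t-5, so the image is always noisier than the sampler assumes and the result keeps residual noise. A minimal sketch of looking up the strided previous alpha instead (the stride constant and the t == 0 handling here are illustrative assumptions):

    # Sketch only: use the cumulative alpha of the previous *sampled* timestep (t - 5
    # in this schedule) instead of alphas_cumprod_prev, which always points at t - 1.
    stride = 5
    alpha_bar = extract(self.alphas_cumprod, t, x.shape)
    t_prev = (t - stride).clamp(min=0)        # previous timestep of the strided schedule
    alpha_bar_prev = extract(self.alphas_cumprod, t_prev, x.shape)
    # At the final step (t == 0) it is simplest to return x_start directly.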