NVlabs / LSGM

The Official PyTorch Implementation of "LSGM: Score-based Generative Modeling in Latent Space" (NeurIPS 2021)

reconstruct image through diffusion #4

Closed: wangherr closed this issue 2 years ago

wangherr commented 2 years ago

To reconstruct an image, there are two ways:

1. origin_img -> [encoder] -> latent  -> [decoder] -> recon_img
2. origin_img -> [encoder] -> latent  -> [diffuse] -> noise -> [reverse diffuse]-> latent -> [decoder] -> recon_img

The first way clearly produces a reconstructed image that is the same as the original image.

In the second way, however, will origin_img and recon_img be the same? I'm not sure whether there is a bug in my code or whether this approach simply cannot do it.

Thanks!

arash-vahdat commented 2 years ago

Yes, the first scenario should reconstruct the image, as it corresponds to the reconstruction loop of our base VAE. In the second scenario, you can also reconstruct the original image well if you use the probability flow ODE for these steps: latent -> [diffuse] -> noise -> [reverse diffuse] -> latent

The probability flow ODE is a deterministic mapping that maps latent variables to noise and back. It is guaranteed to reconstruct the original images if you use a sufficiently small step size when solving the ODEs.
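To illustrate the claim above, here is a minimal sketch on a toy problem (an assumption, not LSGM code): 1-D Gaussian latents N(MU0, S0^2) under a VP-SDE with a linear beta schedule (beta_min = 0.1 and beta_max = 20 are assumed values). In this setting the score is known in closed form, so it stands in for the trained score model (dae). Latents are mapped to noise by integrating the probability flow ODE from t = ODE_EPS to t = 1 with fixed-step RK4 and then mapped back. The round-trip error shrinks as the number of steps grows.

```python
import numpy as np

# Toy setup (assumption): 1-D Gaussian latents N(MU0, S0^2) and a VP-SDE with a
# linear beta schedule. The exact score replaces a trained score network.
BETA_MIN, BETA_MAX = 0.1, 20.0
MU0, S0 = 1.0, 0.5
ODE_EPS = 1e-5  # integrate on [ODE_EPS, 1] rather than [0, 1]


def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def alpha(t):
    # alpha(t) = exp(-0.5 * integral_0^t beta(s) ds)
    return np.exp(-0.5 * (BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t ** 2))


def score(x, t):
    # Exact score of the marginal p_t = N(alpha*MU0, alpha^2*S0^2 + 1 - alpha^2).
    a = alpha(t)
    var = a ** 2 * S0 ** 2 + 1.0 - a ** 2
    return -(x - a * MU0) / var


def pf_ode(x, t):
    # Probability flow ODE: dx/dt = f(x, t) - 0.5 * g(t)^2 * score(x, t).
    return -0.5 * beta(t) * x - 0.5 * beta(t) * score(x, t)


def rk4(f, x, t0, t1, n_steps):
    # Classic fixed-step 4th-order Runge-Kutta; works forward or backward in time.
    h = (t1 - t0) / n_steps
    t = t0
    for _ in range(n_steps):
        k1 = f(x, t)
        k2 = f(x + 0.5 * h * k1, t + 0.5 * h)
        k3 = f(x + 0.5 * h * k2, t + 0.5 * h)
        k4 = f(x + h * k3, t + h)
        x = x + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
        t = t + h
    return x


rng = np.random.default_rng(0)
x0 = MU0 + S0 * rng.standard_normal(1000)  # stand-in for encoder latents


def round_trip(n_steps):
    noise = rk4(pf_ode, x0, ODE_EPS, 1.0, n_steps)     # latent -> noise
    recon = rk4(pf_ode, noise, 1.0, ODE_EPS, n_steps)  # noise -> latent
    return noise, np.max(np.abs(recon - x0))


for n in (4, 40, 1000):
    _, err = round_trip(n)
    print(f"{n:5d} steps: max |recon - x0| = {err:.1e}")
```

For this Gaussian toy, the exact ODE map is x -> m(1) + sqrt(v(1) / v(ODE_EPS)) * (x - m(ODE_EPS)), where m and v are the marginal mean and variance. This lets you check the encoded noise analytically as well.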

You won't be able to reconstruct the input image if you use the stochastic forward and reverse diffusion process.
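The stochastic route can be sketched with Euler-Maruyama on the same kind of toy problem. The latents are assumed Gaussian, the beta schedule is assumed, and an exact score stands in for the trained model; this is not LSGM's sampler. The forward SDE maps latents to noise and the reverse-time SDE maps back, drawing fresh Gaussian noise at every step. The result is a valid sample from the latent distribution, but it is almost uncorrelated with the input.

```python
import numpy as np

# Toy setup (assumption): same Gaussian latents and linear VP-SDE schedule.
BETA_MIN, BETA_MAX = 0.1, 20.0
MU0, S0 = 1.0, 0.5
ODE_EPS = 1e-5
N_STEPS = 1000


def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def alpha(t):
    return np.exp(-0.5 * (BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t ** 2))


def score(x, t):
    # Exact score of the toy marginal p_t (stands in for the score model).
    a = alpha(t)
    var = a ** 2 * S0 ** 2 + 1.0 - a ** 2
    return -(x - a * MU0) / var


def forward_sde(x, rng):
    # Euler-Maruyama for dx = -0.5*beta*x dt + sqrt(beta) dW, t: ODE_EPS -> 1.
    h = (1.0 - ODE_EPS) / N_STEPS
    t = ODE_EPS
    for _ in range(N_STEPS):
        x = x - 0.5 * beta(t) * x * h + np.sqrt(beta(t) * h) * rng.standard_normal(x.shape)
        t += h
    return x


def reverse_sde(x, rng):
    # Euler-Maruyama for the reverse-time SDE, t: 1 -> ODE_EPS.
    h = (1.0 - ODE_EPS) / N_STEPS
    t = 1.0
    for _ in range(N_STEPS):
        drift = -0.5 * beta(t) * x - beta(t) * score(x, t)  # f(x, t) - g(t)^2 * score
        x = x - drift * h + np.sqrt(beta(t) * h) * rng.standard_normal(x.shape)
        t -= h
    return x


rng = np.random.default_rng(0)
x0 = MU0 + S0 * rng.standard_normal(5000)
noise = forward_sde(x0, rng)   # latent -> noise (stochastic)
recon = reverse_sde(noise, rng)  # noise -> latent (stochastic)
print("mean |recon - x0|:", np.mean(np.abs(recon - x0)))
print("corr(x0, recon):  ", np.corrcoef(x0, recon)[0, 1])
```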

wangherr commented 2 years ago

[Q1] Does "a sufficiently small step size" mean setting "args.ode_eps" to a small value, for example the default 0.00001?


[Q2] I tried the following (CIFAR-10, pretrained weights, default args):

logits, all_log_q, all_eps = vae(x)
eps = vae.concat_eps_per_scale(all_eps)[0]

# set noise = eps
eps, nfe, time_ode_solve = diffusion_cont.sample_model_ode(dae, num_samples, shape, ode_eps, ode_solver_tol, enable_autocast, temp, noise=eps)

decomposed_eps = vae.decompose_eps(eps)
recon_img = vae.sample(x.shape[0], 1., decomposed_eps, args.autocast_eval)

x: origins

recon_img: recons

Have I combined the code incorrectly?

Thanks!

arash-vahdat commented 2 years ago

[Q1] In general, I would recommend setting both ode_eps and ode_solver_tol to small values such as (1e-4, 1e-5) respectively.
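The following sketch shows how the two settings differ, again on a toy Gaussian problem with an exact score. Our reading of LSGM's options is an assumption: ode_eps is the small time at which integration stops (the ODE is solved on [ode_eps, 1] rather than [0, 1]), while ode_solver_tol is the error tolerance of an adaptive solver, which is what actually controls the step size. The sketch uses a simple Heun-Euler embedded pair with a combined absolute/relative tolerance; LSGM's actual solver may differ. The function-evaluation count plays the role of the nfe value that sample_model_ode returns.

```python
import numpy as np

# Toy setup (assumption): Gaussian latents, linear VP-SDE schedule, exact score.
BETA_MIN, BETA_MAX = 0.1, 20.0
MU0, S0 = 1.0, 0.5
ODE_EPS = 1e-5  # time where integration stops; not a step size


def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def alpha(t):
    return np.exp(-0.5 * (BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t ** 2))


def pf_ode(x, t):
    a = alpha(t)
    var = a ** 2 * S0 ** 2 + 1.0 - a ** 2
    score = -(x - a * MU0) / var
    return -0.5 * beta(t) * x - 0.5 * beta(t) * score


def solve_adaptive(f, x, t0, t1, tol):
    """Heun-Euler embedded pair with step-size control (atol = rtol = tol, assumed)."""
    t, h, nfe = t0, (t1 - t0) / 100.0, 0
    while t != t1:
        remaining = t1 - t
        step = h if abs(h) < abs(remaining) else remaining
        k1 = f(x, t)
        k2 = f(x + step * k1, t + step)
        nfe += 2
        x_new = x + 0.5 * step * (k1 + k2)            # 2nd-order Heun update
        err = np.max(np.abs(0.5 * step * (k2 - k1)))  # |Heun - Euler| error estimate
        scale = tol * (1.0 + np.max(np.abs(x)))
        if err <= scale:                              # accept the step
            x = x_new
            t = t1 if step == remaining else t + step
        # Grow or shrink the next step (exponent 1/2 for the 1st-order estimate).
        h = step * min(5.0, max(0.2, 0.9 * np.sqrt(scale / max(err, 1e-16))))
    return x, nfe


rng = np.random.default_rng(0)
x0 = MU0 + S0 * rng.standard_normal(1000)
results = {}
for tol in (1e-2, 1e-4, 1e-6):
    noise, nfe_enc = solve_adaptive(pf_ode, x0, ODE_EPS, 1.0, tol)    # latent -> noise
    recon, nfe_dec = solve_adaptive(pf_ode, noise, 1.0, ODE_EPS, tol)  # noise -> latent
    err = np.max(np.abs(recon - x0))
    results[tol] = (nfe_enc + nfe_dec, err)
    print(f"tol={tol:.0e}: nfe={nfe_enc + nfe_dec}, max |recon - x0| = {err:.1e}")
```

A tighter tolerance buys a more accurate round trip at the cost of more function evaluations. Lowering ODE_EPS alone does not shrink the steps.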

[Q2] The issue is that noise is set to eps, which is not correct. You should obtain noise by solving the generative ODE from t=ode_eps to t=1, starting from eps. So the steps will look very similar to yours, with one additional step:

eps = vae.concat_eps_per_scale(all_eps)[0]

# missing step
noise = diffusion_cont.reverse_generative_ode(dae, eps, ... )

eps, nfe, time_ode_solve = diffusion_cont.sample_model_ode(dae, num_samples, shape, ode_eps, ode_solver_tol, enable_autocast, temp, noise)

decomposed_eps = vae.decompose_eps(eps)
recon_img = vae.sample(x.shape[0], 1., decomposed_eps, args.autocast_eval)

Unfortunately, we currently don't have reverse_generative_ode implemented exactly the way you need. However, we do solve this reverse generative ODE when computing the log-likelihood in compute_ode_nll (i.e., the likelihood computation starts from eps and solves the generative ODE from t=ode_eps to t=1). You can modify this function to obtain noise, which is stored in x_t0 in that function.
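Here is an end-to-end sketch of the corrected pipeline on a toy problem. Every name is hypothetical: toy_encode/toy_decode stand in for the VAE, while encode_to_noise and decode_from_noise stand in for the missing reverse_generative_ode step and for ODE sampling from a given noise. None of them are LSGM functions. The toy latents are assumed Gaussian with an exact score, and a fixed-step RK4 solver replaces LSGM's adaptive solver. The sketch also shows why passing noise = eps directly gives a different image.

```python
import numpy as np

# Toy setup (assumption): VP-SDE with a linear beta schedule, Gaussian latents.
BETA_MIN, BETA_MAX = 0.1, 20.0
MU0, S0 = 1.0, 0.5  # toy latent distribution N(MU0, S0^2)
ODE_EPS = 1e-5


def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def alpha(t):
    return np.exp(-0.5 * (BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t ** 2))


def pf_ode(x, t):
    # Probability flow ODE with the exact score of the toy marginal p_t.
    a = alpha(t)
    var = a ** 2 * S0 ** 2 + 1.0 - a ** 2
    score = -(x - a * MU0) / var
    return -0.5 * beta(t) * x - 0.5 * beta(t) * score


def rk4(f, x, t0, t1, n_steps=1000):
    h = (t1 - t0) / n_steps
    t = t0
    for _ in range(n_steps):
        k1 = f(x, t)
        k2 = f(x + 0.5 * h * k1, t + 0.5 * h)
        k3 = f(x + 0.5 * h * k2, t + 0.5 * h)
        k4 = f(x + h * k3, t + h)
        x = x + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
        t = t + h
    return x


# Hypothetical stand-ins; none of these are LSGM functions.
def toy_encode(img):          # plays the role of the VAE encoder
    return (img - 3.0) / 2.0


def toy_decode(eps):          # plays the role of vae.sample with a given eps
    return 2.0 * eps + 3.0


def encode_to_noise(eps):     # the "missing step": ODE from t=ODE_EPS to t=1
    return rk4(pf_ode, eps, ODE_EPS, 1.0)


def decode_from_noise(noise):  # ODE sampling from a given noise: t=1 to t=ODE_EPS
    return rk4(pf_ode, noise, 1.0, ODE_EPS)


rng = np.random.default_rng(0)
img = toy_decode(MU0 + S0 * rng.standard_normal(8))  # toy "images"
eps = toy_encode(img)

recon_ok = toy_decode(decode_from_noise(encode_to_noise(eps)))  # with the missing step
recon_bad = toy_decode(decode_from_noise(eps))                  # noise = eps directly
print("max error with the missing step:", np.max(np.abs(recon_ok - img)))
print("mean error with noise = eps:    ", np.mean(np.abs(recon_bad - img)))
```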

I hope this was helpful.

wangherr commented 2 years ago

Thanks, I have succeeded with this approach:

eps = vae.concat_eps_per_scale(all_eps)[0]

# missing step
noise = diffusion_cont.reverse_generative_ode(dae, eps, ... )

eps, nfe, time_ode_solve = diffusion_cont.sample_model_ode(dae, num_samples, shape, ode_eps, ode_solver_tol, enable_autocast, temp, noise)

decomposed_eps = vae.decompose_eps(eps)
recon_img = vae.sample(x.shape[0], 1., decomposed_eps, args.autocast_eval)

origin_img: origins

recon_img: recons-fode-1e-05-1e-05

However, I noticed the following:

eps = vae.concat_eps_per_scale(all_eps)[0]

noise = diffusion_cont.reverse_generative_ode(dae, eps, ... )

# diffusion_disc.run_denoising_input_diffusion() is a modified copy of diffusion_disc.run_denoising_diffusion():
# it just replaces `x_noisy = torch.randn()` in run_denoising_diffusion() with `x_noisy = eps`
eps = diffusion_disc.run_denoising_input_diffusion(eps, dae, x.shape[0], diff_steps=diff_steps, temp=temp, enable_autocast=enable_autocast, is_image=False, prior_var=prior_var)

decomposed_eps = vae.decompose_eps(eps)
recon_img = vae.sample(x.shape[0], 1., decomposed_eps, args.autocast_eval)

origin_img: origins

recon_img: recons-fb1000-1e-05-1e-05


Could you please share how to reconstruct an image with a method based on diffusion_disc.run_denoising_diffusion()?

Thanks!

arash-vahdat commented 2 years ago

run_denoising_diffusion uses SDE sampling (or, equivalently, the DDPM formulation) to generate samples. In this formulation, the generated eps is almost conditionally independent of the starting x_noisy. That's why you see very diverse images in recon_img in the second case.

In contrast, the ODE sampler has a one-to-one mapping between the generated eps and the starting variable. That's why you can reconstruct eps perfectly.

If you would like something in between, check the DDIM sampler. It has a variance parameter: when set to zero, DDIM behaves like an ODE sampler; when it is non-zero, it injects noise and makes the reconstruction stochastic.
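Here is a minimal discrete-time sketch of the DDIM update with its variance parameter eta. It uses a toy 1-D Gaussian with an exact noise predictor, an assumption standing in for a trained model; it is not LSGM code. The discrete schedule abar_i = alpha(t_i)^2 is also an assumption. The latent is first inverted to x_T with the deterministic (eta = 0) update and then regenerated with different eta values. With eta = 0, x0 is reconstructed up to discretization error. With eta = 1 (the DDPM-like setting), the regenerated samples are nearly independent of x0.

```python
import numpy as np

# Toy setup (assumption): 1-D Gaussian data N(MU0, S0^2); the discrete schedule
# abar_i = alpha(t_i)^2 is taken from a linear VP-SDE on N steps.
BETA_MIN, BETA_MAX = 0.1, 20.0
MU0, S0 = 1.0, 0.5
N = 1000
t_grid = np.linspace(1e-5, 1.0, N + 1)
abar = np.exp(-(BETA_MIN * t_grid + 0.5 * (BETA_MAX - BETA_MIN) * t_grid ** 2))


def eps_hat(x, i):
    # Exact noise prediction for the toy data: -sqrt(1 - abar_i) * score(x, t_i).
    var = abar[i] * S0 ** 2 + 1.0 - abar[i]
    return np.sqrt(1.0 - abar[i]) * (x - np.sqrt(abar[i]) * MU0) / var


def predict_x0(x, e, i):
    return (x - np.sqrt(1.0 - abar[i]) * e) / np.sqrt(abar[i])


def ddim_invert(x0):
    # Deterministic (eta = 0) DDIM steps run forward: x_0 -> x_N.
    x = x0
    for i in range(N):
        e = eps_hat(x, i)
        x = np.sqrt(abar[i + 1]) * predict_x0(x, e, i) + np.sqrt(1.0 - abar[i + 1]) * e
    return x


def ddim_sample(x_T, eta, rng):
    # DDIM generation x_N -> x_0; eta scales the injected noise (eta = 0: deterministic).
    x = x_T
    for i in range(N, 0, -1):
        j = i - 1
        e = eps_hat(x, i)
        sigma = eta * np.sqrt((1.0 - abar[j]) / (1.0 - abar[i]) * (1.0 - abar[i] / abar[j]))
        direction = np.sqrt(np.maximum(1.0 - abar[j] - sigma ** 2, 0.0)) * e
        x = np.sqrt(abar[j]) * predict_x0(x, e, i) + direction + sigma * rng.standard_normal(x.shape)
    return x


rng = np.random.default_rng(0)
x0 = MU0 + S0 * rng.standard_normal(2000)
x_T = ddim_invert(x0)
for eta in (0.0, 0.5, 1.0):
    recon = ddim_sample(x_T, eta, rng)
    print(f"eta = {eta}: mean |recon - x0| = {np.mean(np.abs(recon - x0)):.3f}")
```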