LuChengTHU / dpm-solver

Official code for "DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps" (NeurIPS 2022 Oral)
MIT License

vanilla DDPM with cosine beta schedule, obtain results worse than DDIM #25

Open jiachenlei opened 1 year ago

jiachenlei commented 1 year ago

Hi, thank you for your excellent code and the detailed documentation on how to incorporate DPM-Solver into our own projects!

I tried to replace DDIM with DPM-Solver but failed to obtain comparable results.

Training details of my diffusion model:
(1) Dataset: CelebA-HQ 256x256
(2) Vanilla DDPM (L2 loss, predicts noise), T=1000, UNet, trained in raw pixel space (no latent space)
(3) Beta schedule: cosine schedule (from Improved Denoising Diffusion Probabilistic Models)

Code snippet that uses DPM-solver in my project:

import torch
# assumes dpm_solver_pytorch.py from this repository is on the import path
from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver

model = diffusion_model.model      # nn.Module, takes 256x256x3 images as input and predicts noise
betas = diffusion_model.betas      # cosine schedule

noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas)

model_kwargs = {}
model_fn = model_wrapper(
      model,
      noise_schedule,
      model_type="noise",  # or "x_start" or "v" or "score"
      model_kwargs=model_kwargs,
      guidance_type="uncond",
)

x_T = torch.randn((4, 3, 224, 224), device=device)
dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
x_sample = dpm_solver.sample(
      x_T,
      steps=25,
      order=3,
      skip_type="time_uniform",
      method="singlestep",
      denoise_to_zero=False,
)
img = unnormalize_to_zero_to_one(x_sample)

Result sampled by DDIM after 500 steps: [image: ddim_sample]

Result sampled by DPM-Solver after 25 steps with cosine schedule betas (the schedule used in training): [image: dpm_solver_cosine]

Result sampled by DPM-Solver after 25 steps with linear schedule betas: [image: dpm_solver]

I have tried tuning the DPM-Solver parameters, e.g. multistep instead of singlestep and more sampling steps, but neither helped. Is this result caused by the cosine schedule used to train the diffusion model? Could you please suggest possible improvements? Thank you for your attention!

LuChengTHU commented 1 year ago

Hi @jasonrayshd ,

With the cosine noise schedule, sampling can suffer severe numerical issues for t close to T. In my implementation for the DPM-Solver paper, I changed the start time from t_start=1.0 to t_start=0.9946 (chosen by comparing lambda, i.e. the half-log-SNR). You can try other values of t_start, but do not use t_start=1.0. I will add more elegant support for the cosine schedule later.
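To make the half-log-SNR argument concrete, here is a minimal sketch that evaluates lambda_t = log(alpha_t) - log(sigma_t) for the continuous cosine schedule of Improved DDPM. The helper names `log_alpha` and `half_log_snr` are hypothetical (not part of this repository), and the offset s = 0.008 and the variance-preserving relation sigma_t^2 = 1 - alpha_t^2 are assumptions taken from the usual formulation of that schedule.

```python
import math

COSINE_S = 0.008  # assumed offset s of the Improved DDPM cosine schedule


def log_alpha(t: float, s: float = COSINE_S) -> float:
    """log(alpha_t) for the continuous cosine schedule, with t in (0, 1)."""
    f_t = math.cos((t + s) / (1.0 + s) * math.pi / 2.0)
    f_0 = math.cos(s / (1.0 + s) * math.pi / 2.0)
    return math.log(f_t) - math.log(f_0)


def half_log_snr(t: float, s: float = COSINE_S) -> float:
    """lambda_t = log(alpha_t) - log(sigma_t), assuming sigma_t^2 = 1 - alpha_t^2."""
    la = log_alpha(t, s)
    # -expm1(2 * la) equals 1 - alpha_t^2 and stays accurate when alpha_t is near 1.
    log_sigma = 0.5 * math.log(-math.expm1(2.0 * la))
    return la - log_sigma


if __name__ == "__main__":
    # lambda falls slowly for most of the range, then drops sharply as t approaches 1.
    for t in (0.5, 0.9, 0.99, 0.9946, 0.999, 0.9999):
        print(f"t = {t:<7} lambda = {half_log_snr(t):.4f}")
```

Since alpha_t tends to 0 as t approaches 1, lambda_t falls without bound there, so a solver step that starts exactly at t = 1.0 has to cross a very large lambda interval. Starting slightly earlier, for example by passing `t_start=0.9946` to `dpm_solver.sample(...)` (assuming the keyword is named as in the comment above), avoids that region.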

jiachenlei commented 1 year ago

Got it. Thank you for your great suggestions! Looking forward to your excellent work in the future!

codgodtao commented 1 year ago

@LuChengTHU @jasonrayshd I have tried the cosine schedule, and the result is better than DDIM and DDPM, but I don't know the reason, and I'm not sure whether this observation would also hold for high-dimensional data. Here are the settings:

    model = self.denoise_fn
    model_kwargs = {}
    guidance_scale = w

    noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas)

    model_fn = model_wrapper(
        model,
        noise_schedule,
        model_type="noise",  # or "x_start" or "v" or "score"
        model_kwargs=model_kwargs,
        guidance_type="classifier-free",
        condition=condition,
        unconditional_condition=unconditional_condition,
        guidance_scale=guidance_scale,
    )

    dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++",
                            correcting_x0_fn="dynamic_thresholding")

    x_sample = dpm_solver.sample(
        x_T,
        steps=20,
        order=2,
        skip_type="logSNR",
        method="multistep",
        denoise_to_zero=True,
    )

LuChengTHU commented 1 year ago

Hi guys, I've fixed the numerical issue with the cosine beta schedule; please try the latest version of the DPM-Solver file and see this function for details.

You can also try this script with the ImageNet64 (improved-DDPM checkpoint) example, which uses a cosine schedule.
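The thread does not spell out how the fix works, so the following is only a sketch of one plausible approach for a discrete schedule: compute the half-log-SNR at every discrete step and drop the trailing steps where it falls below a threshold. The function names and the threshold of -5.1 are illustrative assumptions, not necessarily what the repository does; the beta formula is the standard Improved DDPM cosine schedule.

```python
import math


def cosine_betas(num_steps: int = 1000, s: float = 0.008, max_beta: float = 0.999):
    """Discrete cosine betas in the style of Improved DDPM (betas clipped at max_beta)."""
    def alpha_bar(u: float) -> float:
        return math.cos((u + s) / (1.0 + s) * math.pi / 2.0) ** 2

    return [
        min(1.0 - alpha_bar((i + 1) / num_steps) / alpha_bar(i / num_steps), max_beta)
        for i in range(num_steps)
    ]


def half_log_snrs(betas):
    """lambda_i = log(alpha_i) - log(sigma_i) at each discrete step (VP: sigma^2 = 1 - alpha^2)."""
    lambdas, log_alpha = [], 0.0
    for beta in betas:
        log_alpha += 0.5 * math.log1p(-beta)  # log(alpha_i) = 0.5 * sum_j log(1 - beta_j)
        lambdas.append(log_alpha - 0.5 * math.log(-math.expm1(2.0 * log_alpha)))
    return lambdas


def truncate_schedule(betas, min_lambda: float = -5.1):
    """Keep the leading steps whose half-log-SNR is at least min_lambda (hypothetical threshold)."""
    lambdas = half_log_snrs(betas)
    # lambda decreases monotonically with the step index, so this count is a prefix length.
    keep = sum(1 for lam in lambdas if lam >= min_lambda)
    return betas[:keep], lambdas[:keep]


if __name__ == "__main__":
    betas = cosine_betas()
    kept_betas, kept_lambdas = truncate_schedule(betas)
    print(f"kept {len(kept_betas)} of {len(betas)} steps")
    print(f"lowest kept lambda: {kept_lambdas[-1]:.4f}")
```

Truncating in this way plays the same role as choosing `t_start` below 1.0: the sampler never evaluates the region where alpha is nearly zero and the half-log-SNR changes by a large amount between neighbouring steps.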