inbarhub / DDPM_inversion

Official pytorch implementation of the paper: "An Edit Friendly DDPM Noise Space: Inversion and Manipulations". CVPR 2024.
https://inbarhub.github.io/DDPM_inversion/
MIT License

Application to other schedulers #1

Closed: bonlime closed this issue 2 months ago.

bonlime commented 1 year ago

Hi @inbarhub, first of all, thanks for a very good and interesting paper; I really enjoyed reading it.

I wonder if it's possible to apply the derived noise maps to schedulers other than DDPM/DDIM. For example, have you tried substituting the noise maps into the Euler Ancestral sampler? DDIM/DDPM in general seem to produce lower-quality results or require a larger number of steps.

daiqing-qi commented 1 year ago

Hi @bonlime, have you tried that? It seems that null-text inversion uses DDIM with prediction = 'epsilon'. I am not sure whether null-text inversion supports other options. I am curious as well.
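Whatever null-text inversion itself supports, a deterministic DDIM step only needs an ε (noise) estimate, and `sample` or `v_prediction` outputs can be converted to ε with the standard relations x_t = √ᾱ·x_0 + √(1−ᾱ)·ε and v = √ᾱ·ε − √(1−ᾱ)·x_0 (the scheduler code posted below in this thread uses the same v-to-ε conversion). A minimal NumPy sketch; `to_epsilon` is a hypothetical helper, not a diffusers API.

```python
import numpy as np


def to_epsilon(model_output, sample, alpha_prod_t, prediction_type):
    """Convert a model prediction into an epsilon (noise) prediction.

    Uses x_t = sqrt(a) * x_0 + sqrt(1 - a) * eps and
    v = sqrt(a) * eps - sqrt(1 - a) * x_0, where a = alpha_prod_t.
    """
    a = alpha_prod_t
    if prediction_type == "epsilon":
        return model_output
    if prediction_type == "sample":
        # model_output is a prediction of x_0; solve the forward relation for eps
        return (sample - np.sqrt(a) * model_output) / np.sqrt(1 - a)
    if prediction_type == "v_prediction":
        # same conversion as the v_prediction branch of the scheduler code
        return np.sqrt(a) * model_output + np.sqrt(1 - a) * sample
    raise ValueError(f"unknown prediction_type: {prediction_type}")


# Usage: build x_t and v from a known x_0 / eps pair, then recover eps.
rng = np.random.default_rng(0)
x0, eps = rng.standard_normal(4), rng.standard_normal(4)
a = 0.7
x_t = np.sqrt(a) * x0 + np.sqrt(1 - a) * eps
v = np.sqrt(a) * eps - np.sqrt(1 - a) * x0
print(np.allclose(to_epsilon(v, x_t, a, "v_prediction"), eps))  # True
```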

bonlime commented 1 year ago

@daiqing-qi No, I ended up not trying this paper. You're right that null-text uses DDIM inversion with prediction "epsilon", but I haven't yet seen a correct version of DDIM inversion. There is one in 🤗 diffusers, but it's incorrect and no one cares about it. I tried reimplementing it myself; mine is more correct, but the math is still slightly off.

Here is proof that DDIM inversion in 🤗 is incorrect. With strength 0.5, going from image (1) to noise and back results in image (2), which loses saturation. My implementation results in image (3), which looks better at first.

[Images: (1) input, (2) 🤗 diffusers round trip, (3) my implementation round trip]

But after applying the inversion to the same image multiple times, 🤗 produces image (4) and my version produces image (5), which shows that mine is also incorrect, just in the opposite direction. The inverted DPM++ scheduler in 🤗 suffers from exactly the same problems.

[Images: (4) 🤗 after repeated inversion, (5) my implementation after repeated inversion]
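To quantify this drift instead of judging it by eye, one option is to track mean HSV saturation across repeated round trips. A minimal NumPy sketch, assuming RGB images in [0, 1] with shape (H, W, 3); `round_trip` stands in for a hypothetical invert-then-reconstruct function and is not part of any library.

```python
import numpy as np


def mean_saturation(rgb):
    """Mean HSV saturation, (max - min) / max per pixel, 0 for black pixels."""
    rgb = np.asarray(rgb, dtype=np.float64)
    cmax = rgb.max(axis=-1)
    cmin = rgb.min(axis=-1)
    safe_max = np.where(cmax > 0, cmax, 1.0)  # avoid division by zero
    sat = np.where(cmax > 0, (cmax - cmin) / safe_max, 0.0)
    return float(sat.mean())


def saturation_drift(image, round_trip, n_iters):
    """Apply `round_trip` repeatedly; return saturation before and after each pass."""
    history = [mean_saturation(image)]
    for _ in range(n_iters):
        image = round_trip(image)
        history.append(mean_saturation(image))
    return history


# Usage with a toy round trip that blends 10% towards gray on every pass,
# mimicking the steady saturation loss described above.
img = np.array([[[1.0, 0.0, 0.0]]])
desaturate = lambda im: 0.9 * im + 0.1 * im.mean(axis=-1, keepdims=True)
print(saturation_drift(img, desaturate, 3))
```

A monotone trend in `history` (down for one implementation, up for another) is the signature of a systematic bias rather than random reconstruction error.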

daiqing-qi commented 1 year ago

Hi @bonlime, thanks for your reply! It is very helpful. I think it is reasonable that applying the inversion to the same image multiple times leads to a different image: the inversion only makes the reconstructed image look similar, while invisible errors can accumulate. Could you share your code/implementation of DDIM inversion? Thanks!

bonlime commented 1 year ago

Sure, the code is not much of a secret :)

Main code is here:

```python
from typing import Optional, Tuple, Union

import torch
from diffusers import DDIMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput


# The only difference from diffusers is the `reverse` option
# (update: also supports `timestep` as a Tensor in `.step`)
class DDIMSchedulerReversible(DDIMScheduler):
    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator: torch.Generator = None,
        variance_noise: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
        reverse: bool = False,
    ) -> Union[DDIMSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`): current instance of sample being created by diffusion process.
            eta (`float`): weight of noise for added noise in diffusion step.
            use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
                predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
                `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
                coincide with the one provided as input and `use_clipped_model_output` will have no effect.
            generator: random number generator.
            variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`,
                we can directly provide the noise for the variance itself. This is useful for methods such as
                CycleDiffusion. (https://arxiv.org/abs/2210.05559)
            return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
            reverse (`bool`): if `True`, take an inversion step (towards higher noise) by swapping the two
                cumulative alphas.

        Returns:
            [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is the sample tensor.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read DDIM paper in-detail understanding

        # Notation (<variable name> -> <name in paper>
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        if isinstance(timestep, torch.Tensor) and timestep.numel() > 1:
            alpha_prod_t_prev = torch.where(
                prev_timestep >= 0,
                self.alphas_cumprod[prev_timestep],
                self.final_alpha_cumprod,
            )
            # add dimensions for broadcasting
            alpha_prod_t_prev = alpha_prod_t_prev.view(-1, 1, 1, 1).to(sample.device)
            alpha_prod_t = alpha_prod_t.view(-1, 1, 1, 1).to(sample.device)
        else:
            alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        if reverse:
            # inversion: go from the less noisy level to the noisier one
            alpha_prod_t, alpha_prod_t_prev = alpha_prod_t_prev, alpha_prod_t

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            # convert the predicted v into a predicted epsilon, used as model_output below
            model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. Clip "predicted x_0"
        if self.config.clip_sample:
            pred_original_sample = torch.clamp(pred_original_sample, -1, 1)

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        if eta > 0:
            variance = self._get_variance(timestep, prev_timestep)
            std_dev_t = eta * variance ** (0.5)
        else:
            std_dev_t = 0

        if use_clipped_model_output:
            # the model_output is always re-derived from the clipped x_0 in Glide
            model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output

        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        if eta > 0:
            # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
            device = model_output.device
            if variance_noise is not None and generator is not None:
                raise ValueError(
                    "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
                    " `variance_noise` stays `None`."
                )
            if variance_noise is None:
                if device.type == "mps":
                    # randn does not work reproducibly on mps
                    variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
                    variance_noise = variance_noise.to(device)
                else:
                    variance_noise = torch.randn(
                        model_output.shape, generator=generator, device=device, dtype=model_output.dtype
                    )
            variance = std_dev_t * variance_noise
            prev_sample = prev_sample + variance

        if not return_dict:
            return (prev_sample,)

        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
```
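The `reverse=True` branch only swaps the two cumulative alphas, so one inversion step predicts x_0 at the current level and re-noises it to the next, noisier level. The small NumPy sketch below (toy noise predictors, hypothetical helper names, not diffusers code) illustrates why a round trip can still drift: during inversion the noise is predicted from the current, less noisy sample at the target timestep, because the noisier sample is not known yet. Reconstruction is therefore exact only when the noise prediction does not depend on its input.

```python
import numpy as np


def ddim_step(sample, eps, alpha_from, alpha_to):
    """Deterministic DDIM update (eta = 0) between two cumulative-alpha levels.

    Denoising uses alpha_from < alpha_to; inversion simply swaps the roles,
    which is what the `reverse=True` branch does.
    """
    x0 = (sample - np.sqrt(1 - alpha_from) * eps) / np.sqrt(alpha_from)
    return np.sqrt(alpha_to) * x0 + np.sqrt(1 - alpha_to) * eps


def invert(x, alphas, eps_fn):
    """alphas: decreasing cumulative alphas, starting at the level of x."""
    for a_from, a_to in zip(alphas[:-1], alphas[1:]):
        eps = eps_fn(x, a_to)  # approximation: current x, target level
        x = ddim_step(x, eps, a_from, a_to)
    return x


def denoise(x, alphas, eps_fn):
    """Walk back from the noisiest level in `alphas` to alphas[0]."""
    rev = alphas[::-1]
    for a_from, a_to in zip(rev[:-1], rev[1:]):
        eps = eps_fn(x, a_from)  # exact: noise predicted at the current level
        x = ddim_step(x, eps, a_from, a_to)
    return x


alphas = [0.9, 0.7, 0.5, 0.3]
x = np.array([1.0, -0.5, 0.25])
const_eps = lambda x, a: np.full_like(x, 0.3)  # ignores its input
linear_eps = lambda x, a: 0.5 * x              # depends on its input
print(np.abs(denoise(invert(x, alphas, const_eps), alphas, const_eps) - x).max())
print(np.abs(denoise(invert(x, alphas, linear_eps), alphas, linear_eps) - x).max())
```

With a real model the predictor always depends on its input, so the mismatch never vanishes, and applying the round trip repeatedly can accumulate a consistent bias.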

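It also requires sampling with flipped timesteps:

```python
# Inside the pipeline's inversion method: iterate over the timesteps in ascending
# order (clean -> noisy), keeping only the steps implied by `strength`.
self.inverse_scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps_ = self.get_timesteps(num_inference_steps, strength, device)
for t in self.inverse_scheduler.timesteps.flip(0)[:num_inference_steps_]:
    ...  # predict the noise at t, then call self.inverse_scheduler.step(..., reverse=True)
```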
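To see which timesteps the flipped loop visits, here is a pure-Python sketch. It assumes DDIM's "leading" timestep spacing with `steps_offset=1` (a Stable Diffusion-style config) and re-implements the arithmetic of the img2img `get_timesteps` helper; the function names are hypothetical. With 50 steps and `strength=0.5`, inversion walks 1, 21, ..., 481, exactly the timesteps that denoising then visits in reverse order.

```python
def leading_timesteps(num_inference_steps, num_train_timesteps=1000, steps_offset=1):
    """Descending timesteps with 'leading' spacing (assumed scheduler config)."""
    step_ratio = num_train_timesteps // num_inference_steps
    return [i * step_ratio + steps_offset for i in reversed(range(num_inference_steps))]


def inversion_and_denoising_timesteps(num_inference_steps, strength):
    timesteps = leading_timesteps(num_inference_steps)
    # same arithmetic as the img2img get_timesteps helper
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    denoising = timesteps[t_start:]
    num_steps = num_inference_steps - t_start  # second value returned by get_timesteps
    inversion = timesteps[::-1][:num_steps]    # the `.flip(0)[:num_inference_steps_]` slice
    return inversion, denoising


inversion, denoising = inversion_and_denoising_timesteps(50, 0.5)
print(inversion[:3], inversion[-1])  # [1, 21, 41] 481
print(denoising[:3])                 # [481, 461, 441]
```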

> I think applying the inversion on the same image for multiple times leads to another image is reasonable

The problem is not that we get a different image, but that the saturation changes. This effect is very consistent in both 🤗 and my implementation: saturation always decreases with 🤗 and always increases with my version. Carefully solving the inversion math might fix it, but I didn't have time to do that. Maybe you will :)

inbarhub commented 1 year ago

Hi,

From the update equation of Euler Ancestral, it seems that you can extract the noise the same way we did for DDPM. However, we haven't tried this.
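This can be made concrete: an ancestral step is a deterministic move plus `sigma_up * noise`, so given two consecutive latents and the model's denoised estimate, the noise connecting them can be solved for in closed form. A NumPy sketch, assuming the k-diffusion formulation of Euler Ancestral (eta = 1, derivative `(x - denoised) / sigma`); the toy denoiser and the random intermediate latents are placeholders, and how to choose those latents for an inversion is not shown. It also assumes every target sigma is positive, since the final step to sigma = 0 adds no noise and has nothing to extract.

```python
import numpy as np


def ancestral_sigmas(sigma_from, sigma_to, eta=1.0):
    """Split a step into a deterministic target (sigma_down) and fresh noise (sigma_up)."""
    sigma_up = min(sigma_to, eta * np.sqrt(sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2))
    sigma_down = np.sqrt(max(sigma_to**2 - sigma_up**2, 0.0))
    return sigma_down, sigma_up


def deterministic_part(x, sigma_from, sigma_to, denoise):
    """Euler move to sigma_down along d = (x - denoised) / sigma."""
    sigma_down, sigma_up = ancestral_sigmas(sigma_from, sigma_to)
    d = (x - denoise(x, sigma_from)) / sigma_from
    return x + d * (sigma_down - sigma_from), sigma_up


def euler_ancestral_step(x, sigma_from, sigma_to, denoise, noise):
    mean, sigma_up = deterministic_part(x, sigma_from, sigma_to, denoise)
    return mean + sigma_up * noise


def extract_noise(x, x_next, sigma_from, sigma_to, denoise):
    """Solve euler_ancestral_step(x, ...) == x_next for `noise` (needs sigma_up > 0)."""
    mean, sigma_up = deterministic_part(x, sigma_from, sigma_to, denoise)
    return (x_next - mean) / sigma_up


# Usage: recover the noises for an arbitrary trajectory, then replay the sampler.
rng = np.random.default_rng(0)
sigmas = [10.0, 5.0, 2.0, 1.0, 0.5]
toy_denoise = lambda x, s: x / (1.0 + s**2)  # placeholder for the real model
xs = [s * rng.standard_normal(4) for s in sigmas]
zs = [extract_noise(xs[i], xs[i + 1], sigmas[i], sigmas[i + 1], toy_denoise) for i in range(len(sigmas) - 1)]

x = xs[0]
for i, z in enumerate(zs):
    x = euler_ancestral_step(x, sigmas[i], sigmas[i + 1], toy_denoise, z)
print(np.allclose(x, xs[-1]))  # True
```

Editing would then mean replaying the same extracted noises with a different prompt or condition, as with the DDPM noise maps.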