google / prompt-to-prompt


Inconsistency between timestep and noise level in Null-Text Inversion? #64

Open KiwiXR opened 11 months ago

KiwiXR commented 11 months ago

Thanks for your excellent work!

While digging into the code of Null-Text Inversion, I found something confusing.

First, according to the formula in your paper, DDIM inversion is written as follows: [image: formula from the paper], where I assume $\epsilon_\theta(z_t)$ represents $\epsilon_\theta(z_t, t)$.

Then, when I read the code provided in null_text_w_ptp.ipynb, I was somewhat confused by the implementation.

This block implements the inversion loop:

class NullInversion:
    ...
    @torch.no_grad()
    def ddim_loop(self, latent):
        uncond_embeddings, cond_embeddings = self.context.chunk(2)
        all_latent = [latent]
        latent = latent.clone().detach()
        for i in range(NUM_DDIM_STEPS):
            t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1]
            noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings)
            latent = self.next_step(noise_pred, t, latent)
            all_latent.append(latent)
        return all_latent

And this one implements a single step of inversion:

class NullInversion:
    ...
    def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
        timestep, next_timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
        beta_prod_t = 1 - alpha_prod_t
        next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
        next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
        next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
        return next_sample

If I understand correctly, in ddim_loop the variable noise_pred corresponds to $\epsilon_\theta(latent, t)$, which means that latent is treated as $z_t$. However, in next_step, the passed-in timestep (i.e., $t$) is renamed to next_timestep, so the new timestep and next_timestep correspond to $t-1$ and $t$, respectively.
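To make the renaming concrete, here is a small standalone sketch that traces which timestep the UNet sees and which alphas_cumprod indices next_step uses in each iteration. The values (1000 training timesteps, 50 DDIM steps, a descending schedule 981, 961, ..., 1) and the helper name inversion_schedule are my own assumptions for illustration, not taken from the notebook.

```python
NUM_TRAIN_TIMESTEPS = 1000  # assumed value of scheduler.config.num_train_timesteps
NUM_DDIM_STEPS = 50         # assumed value of scheduler.num_inference_steps


def inversion_schedule(timesteps):
    """Trace the indices used by ddim_loop and next_step (hypothetical helper).

    Returns one tuple per iteration:
    (t passed to the UNet, alpha index used for the input latent,
     alpha index used for the output latent).
    """
    stride = NUM_TRAIN_TIMESTEPS // NUM_DDIM_STEPS
    rows = []
    for i in range(len(timesteps)):
        t = timesteps[len(timesteps) - i - 1]  # same indexing as ddim_loop
        prev_t = min(t - stride, 999)          # `timestep` inside next_step
        rows.append((t, prev_t, t))            # `next_timestep` is the passed-in t
    return rows


# Assumed descending schedule: 981, 961, ..., 1 (50 entries).
timesteps = list(range(981, 0, -20))
rows = inversion_schedule(timesteps)
for unet_t, alpha_in, alpha_out in rows[:3]:
    print(f"UNet t={unet_t:4d}  input alpha index={alpha_in:4d}  output alpha index={alpha_out:4d}")
```

Under these assumptions, the input latent is treated as living at index t - 20 (a negative index falls back to final_alpha_cumprod in next_step), while the noise prediction is evaluated at t.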

Therefore, I think the code actually gives: $$z_{t+1}=\sqrt{\frac{\alpha_t}{\alpha_{t-1}}}z_t+\sqrt{\alpha_t}\cdot\Bigg(\sqrt{\frac{1}{\alpha_t}-1} - \sqrt{\frac{1}{\alpha_{t-1}} - 1}\Bigg)\cdot\epsilon_\theta(z_t,t)$$
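As a sanity check on the algebra, the following NumPy sketch compares the arithmetic of next_step with the closed form above on random arrays. The alpha values and array shapes are arbitrary placeholders (not real scheduler values), and the function names are mine.

```python
import numpy as np


def next_step_update(sample, eps, alpha_prev, alpha_t):
    """Same arithmetic as next_step, with the two alphas passed in explicitly."""
    pred_x0 = (sample - (1 - alpha_prev) ** 0.5 * eps) / alpha_prev ** 0.5
    direction = (1 - alpha_t) ** 0.5 * eps
    return alpha_t ** 0.5 * pred_x0 + direction


def closed_form(sample, eps, alpha_prev, alpha_t):
    """Closed form with alpha_prev = alpha_{t-1} and alpha_t = alpha_t."""
    scale = np.sqrt(alpha_t / alpha_prev)
    coeff = np.sqrt(alpha_t) * (np.sqrt(1 / alpha_t - 1) - np.sqrt(1 / alpha_prev - 1))
    return scale * sample + coeff * eps


rng = np.random.default_rng(0)
z = rng.standard_normal((4, 8, 8))    # stand-in for the latent
eps = rng.standard_normal((4, 8, 8))  # stand-in for the noise prediction
alpha_prev, alpha_t = 0.95, 0.90      # placeholder cumulative alphas

max_diff = np.max(np.abs(next_step_update(z, eps, alpha_prev, alpha_t)
                         - closed_form(z, eps, alpha_prev, alpha_t)))
print(max_diff)  # expected to be at floating-point rounding level
```

This only confirms that the two expressions agree numerically; it says nothing about which timestep the noise prediction should be evaluated at.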

This is really confusing to me. Please help me out!

KiwiXR commented 11 months ago

To support my point, I modified ddim_loop as follows:

class NullInversion:
    ...
    @torch.no_grad()
    def ddim_loop(self, latent):
        uncond_embeddings, cond_embeddings = self.context.chunk(2)
        all_latent = [latent]
        latent = latent.clone().detach()
        for i in range(NUM_DDIM_STEPS):
            t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1]
            next_t = min(t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)  # copied from `self.next_step`
            noise_pred = self.get_noise_pred_single(latent, next_t, cond_embeddings)  # modified
            latent = self.next_step(noise_pred, t, latent)
            all_latent.append(latent)
        return all_latent

Then I calculated the PSNR of the null-text inverted image against the cat example image:

from skimage.metrics import peak_signal_noise_ratio
psnr = peak_signal_noise_ratio(image_gt, image_inv[0])
print(psnr)

The original version gives 29.56082923568291, while the modified version gives 29.605481827030523 (higher is better).
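For reference, here is a minimal NumPy sketch of PSNR for 8-bit images, assuming a data range of 255 (which, as far as I know, is what skimage uses for uint8 inputs), plus a conversion of the PSNR gap reported above into an MSE ratio. The function name psnr_uint8 and the toy arrays are mine.

```python
import numpy as np


def psnr_uint8(reference, test):
    """PSNR in dB for 8-bit images, assuming data_range = 255.

    Identical images give mse == 0 and therefore an infinite PSNR.
    """
    ref = reference.astype(np.float64)
    tst = test.astype(np.float64)
    mse = np.mean((ref - tst) ** 2)
    return 10 * np.log10(255.0 ** 2 / mse)


# Toy example: every pixel differs by exactly 1, so MSE = 1.
reference = np.zeros((4, 4), dtype=np.uint8)
test = np.ones((4, 4), dtype=np.uint8)
toy_psnr = psnr_uint8(reference, test)

# A PSNR gap of d dB corresponds to an MSE ratio of 10 ** (d / 10).
psnr_original = 29.56082923568291
psnr_modified = 29.605481827030523
mse_ratio = 10 ** ((psnr_modified - psnr_original) / 10)
print(toy_psnr, mse_ratio)
```

By this conversion, the gap of about 0.045 dB corresponds to roughly 1% lower MSE for the modified version on this single image.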

I hope this helps demonstrate my point.