KiwiXR opened this issue 11 months ago.
To support my reasoning, I further modified the code in `ddim_loop` as follows:
```python
class NullInversion:
    ...  # other methods unchanged

    @torch.no_grad()
    def ddim_loop(self, latent):
        uncond_embeddings, cond_embeddings = self.context.chunk(2)
        all_latent = [latent]
        latent = latent.clone().detach()
        for i in range(NUM_DDIM_STEPS):
            # walk the scheduler's (descending) timesteps from the end, i.e. in ascending order
            t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1]
            # next_t = t - step: the timestep that next_step treats as the *current* one
            next_t = min(t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)  # copied from `self.next_step`
            noise_pred = self.get_noise_pred_single(latent, next_t, cond_embeddings)  # modified: was `t`
            latent = self.next_step(noise_pred, t, latent)
            all_latent.append(latent)
        return all_latent
```
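To make the change concrete, the sketch below tabulates, for each inversion step, the pair of timesteps the update moves between and the timestep at which the UNet is queried, in the original and the modified loop. It assumes 1000 training timesteps, 50 DDIM steps and evenly spaced timesteps `0, 20, ..., 980` (no `steps_offset`); `timestep_schedule` is a hypothetical helper and no model is run.

```python
NUM_TRAIN_TIMESTEPS = 1000  # assumed: scheduler.config.num_train_timesteps
NUM_DDIM_STEPS = 50         # assumed: number of DDIM inference steps
STEP = NUM_TRAIN_TIMESTEPS // NUM_DDIM_STEPS  # 20


def timestep_schedule(modified):
    """Return a list of (eval_t, from_t, to_t) tuples, one per inversion step.

    from_t / to_t follow the reinterpretation inside next_step: the passed-in t
    becomes the target timestep and min(t - STEP, 999) the source timestep.
    eval_t is the timestep at which the UNet would be queried.
    """
    # Ascending timesteps 0, 20, ..., 980 (assumes no steps_offset), i.e. what
    # timesteps[len(timesteps) - i - 1] yields for i = 0, 1, ...
    rows = []
    for t in range(0, NUM_TRAIN_TIMESTEPS, STEP):
        from_t = min(t - STEP, 999)
        eval_t = from_t if modified else t
        rows.append((eval_t, from_t, t))
    return rows


orig_rows = timestep_schedule(modified=False)
mod_rows = timestep_schedule(modified=True)
for (e_orig, src, dst), (e_mod, _, _) in zip(orig_rows[:3], mod_rows[:3]):
    print(f"{src:>4} -> {dst:>4}: original queries t={e_orig}, modified queries t={e_mod}")
```

Under these assumptions the original loop queries $\epsilon_\theta$ at the target timestep of each step, whereas the modified loop queries it at the source timestep; the modified loop's very first query therefore lands on a negative timestep ($-20$).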
Then I computed the PSNR of the null-text inverted image on the cat example image:
```python
from skimage.metrics import peak_signal_noise_ratio

# image_gt: the original image; image_inv[0]: its null-text inversion reconstruction
psnr = peak_signal_noise_ratio(image_gt, image_inv[0])
print(psnr)
```
The original version gives 29.56082923568291 dB, while the modified version gives 29.605481827030523 dB (greater is better).
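For scale: PSNR is defined as $10\log_{10}(\text{data\_range}^2/\text{MSE})$, which is the definition `peak_signal_noise_ratio` uses, so a PSNR difference on the same image maps directly to an MSE ratio. The sketch below is plain arithmetic on the two numbers above; the `psnr` helper is a hypothetical NumPy re-statement of the formula, assuming 8-bit images (`data_range=255`).

```python
import math

import numpy as np


def psnr(image_true: np.ndarray, image_test: np.ndarray, data_range: float = 255.0) -> float:
    """PSNR in dB: 10 * log10(data_range**2 / MSE), computed in float64."""
    mse = np.mean((image_true.astype(np.float64) - image_test.astype(np.float64)) ** 2)
    return 10 * math.log10(data_range ** 2 / mse)


psnr_original = 29.56082923568291
psnr_modified = 29.605481827030523

# For the same image and data_range:
# MSE_modified / MSE_original = 10 ** (-(PSNR_modified - PSNR_original) / 10)
delta_db = psnr_modified - psnr_original
mse_ratio = 10 ** (-delta_db / 10)

print(f"PSNR gain: {delta_db:.4f} dB")
print(f"MSE ratio (modified / original): {mse_ratio:.4f}")
```

So on this single image the modified loop lowers the reconstruction MSE by roughly 1%.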
I hope this demonstrates my point.
Thanks for your excellent work!
While digging into the code of Null-Text Inversion, I found something confusing.
Firstly, your paper gives a formula for DDIM inversion (the formula image is not preserved here), in which I assume $\epsilon_\theta(z_t)$ represents $\epsilon_\theta(z_t, t)$.
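For reference, the standard deterministic DDIM inversion step, obtained by solving the DDIM sampling update for the noisier latent, can be written as below. Here $\alpha_t$ is the cumulative product of $1-\beta$ (what diffusers calls `alphas_cumprod`); this is the textbook form, offered as an assumption about what the paper's formula expresses rather than a copy of it. Note that $\epsilon_\theta$ is evaluated at the source latent and its own timestep, $(z_t, t)$:

$$
\begin{aligned}
z_{t+1} &= \sqrt{\alpha_{t+1}}\cdot\frac{z_t-\sqrt{1-\alpha_t}\,\epsilon_\theta(z_t,t)}{\sqrt{\alpha_t}}+\sqrt{1-\alpha_{t+1}}\cdot\epsilon_\theta(z_t,t) \\
&= \sqrt{\frac{\alpha_{t+1}}{\alpha_t}}\,z_t+\sqrt{\alpha_{t+1}}\cdot\Bigg(\sqrt{\frac{1}{\alpha_{t+1}}-1}-\sqrt{\frac{1}{\alpha_t}-1}\Bigg)\cdot\epsilon_\theta(z_t,t)
\end{aligned}
$$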
Then, reading the code provided in `null_text_w_ptp.ipynb`, I was somewhat confused by the implementation.
The inversion loop is implemented in `ddim_loop`, and a single inversion step in `next_step` (the code excerpts originally quoted here were not preserved).
If I understand it correctly, the variable `noise_pred` in `ddim_loop` corresponds to $\epsilon_\theta(latent, t)$, which indicates that `latent` is used as $z_t$. However, in `next_step` the passed-in timestep (i.e., $t$) is renamed to `next_timestep`, and the new `timestep` and `next_timestep` correspond to $t-1$ and $t$. Therefore, I think the code actually computes:

$$z_{t+1}=\sqrt{\frac{\alpha_t}{\alpha_{t-1}}}\,z_t+\sqrt{\alpha_t}\cdot\Bigg(\sqrt{\frac{1}{\alpha_t}-1} - \sqrt{\frac{1}{\alpha_{t-1}} - 1}\Bigg)\cdot\epsilon_\theta(z_t,t)$$
This is really confusing to me; please help me out!