matanat closed this issue 9 months ago.
Hi @matanat,
Thanks for spotting this. It probably went unnoticed for so long because most sampling code just calls step(), and reversed_step() seems to be referenced only in this tutorial.
Your solution looks right to me, and I agree with renaming the timestep variables to make the differences between this and the step() function clearer. Just CCing @SANCHES-Pedro to check he agrees.
If you're willing to make a PR to fix it, that would be great!
Hi @marksgraham, I have a branch with the fix, but I don't have permission to create a PR.
Hi @matanat,
I think you will need to fork the repo, push your changes to a branch on your fork, and then open a PR from there. Have you tried that?
There's a bug in how the DDIM scheduler computes the next alpha_prod in reversed_step() for the last step of the inversion (the largest value in self.timesteps), where the next timestep falls past the end of alphas_cumprod:
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
This line was just copied from the step() function (the variable names were copied too, which makes it even more confusing), but here it should handle the first alpha (the last entry in the alphas_cumprod array). Otherwise, you get an exception when trying to invert an image all the way through the timesteps. I think this should be:
alpha_prod_t_next = self.alphas_cumprod[next_timestep] if next_timestep < len(self.alphas_cumprod) else self.first_alpha_cumprod
And we should set the following:
self.first_alpha_cumprod = torch.tensor(0.0)
or use the last alpha_cumprod in the array instead (this should probably be configurable). What do you think? Should I create a pull request to fix this?
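A minimal pure-Python sketch of the index arithmetic behind this bug, assuming a hypothetical linear beta schedule with 1000 training timesteps and 50 inference steps (a step ratio of 20). Plain lists stand in for the scheduler's tensors, and make_alphas_cumprod, next_alpha_copied, next_alpha_fixed and inversion_alphas are hypothetical helpers for illustration, not the scheduler's API. It shows that the copied guard (prev_timestep >= 0) can never trigger in the reversed direction, so the largest timestep indexes one past the end of alphas_cumprod, while the proposed next_timestep < len(alphas_cumprod) check falls back to first_alpha_cumprod.

```python
def make_alphas_cumprod(num_train_timesteps=1000, beta_start=1e-4, beta_end=2e-2):
    """Cumulative product of (1 - beta) for a hypothetical linear beta schedule."""
    alphas_cumprod = []
    prod = 1.0
    for i in range(num_train_timesteps):
        beta = beta_start + (beta_end - beta_start) * i / (num_train_timesteps - 1)
        prod *= 1.0 - beta
        alphas_cumprod.append(prod)
    return alphas_cumprod


def next_alpha_copied(alphas_cumprod, timestep, step_ratio, final_alpha_cumprod=1.0):
    """Lookup as copied from step(): the guard only catches negative indices."""
    prev_timestep = timestep + step_ratio  # misnamed: this is really the next timestep
    return alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alpha_cumprod


def next_alpha_fixed(alphas_cumprod, timestep, step_ratio, first_alpha_cumprod=0.0):
    """Proposed lookup: fall back to first_alpha_cumprod past the end of the array."""
    next_timestep = timestep + step_ratio
    if next_timestep < len(alphas_cumprod):
        return alphas_cumprod[next_timestep]
    return first_alpha_cumprod


def inversion_alphas(alphas_cumprod, num_inference_steps, first_alpha_cumprod=0.0):
    """Next-alpha values visited when inverting through every timestep."""
    step_ratio = len(alphas_cumprod) // num_inference_steps
    # Inversion walks the timesteps in ascending order,
    # e.g. 0, 20, ..., 980 for 1000 training / 50 inference steps
    timesteps = range(0, len(alphas_cumprod), step_ratio)
    return [next_alpha_fixed(alphas_cumprod, t, step_ratio, first_alpha_cumprod) for t in timesteps]


if __name__ == "__main__":
    alphas_cumprod = make_alphas_cumprod()
    try:
        # Largest timestep: 980 + 20 = 1000 is one past the last valid index
        next_alpha_copied(alphas_cumprod, timestep=980, step_ratio=20)
    except IndexError as err:
        print("copied lookup fails at t=980:", err)
    alphas = inversion_alphas(alphas_cumprod, num_inference_steps=50)
    print(len(alphas), alphas[-1])  # 50 0.0
```

With lists the failure shows up as an IndexError; passing first_alpha_cumprod=alphas_cumprod[-1] instead corresponds to the other option suggested in the issue, and which default to use is left open in the thread.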