yangxiaofeng / rectified_flow_prior

Official code for paper: Text-to-Image Rectified Flow as Plug-and-Play Priors

Questions about the core code of RFDS_Rev_sd3. #3

Open yiboz2001 opened 2 weeks ago

yiboz2001 commented 2 weeks ago

Hi! Could I ask some questions about the core code of RFDS_Rev_sd3?

I am confused about lines 335~348 of RFDS_Rev_sd3.py.

sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)
latents_noisy = sigmas * noise + (1.0 - sigmas) * latents            # pred noise
# iRFDS
noise_pred = self.forward_transformer(
    self.transformer,
    latents_noisy,
    timesteps,
    text_embeddings,
    text_embeddings_pooled,
    )
# https://github.com/huggingface/diffusers/blob/614d0c64e96b37740e14bb5c2eca8f8a2ecdf23e/examples/dreambooth/train_dreambooth_lora_sd3.py#L1481
new_latent = noise_pred * (-sigmas) + latents_noisy # the new latents
noise = noise_pred + new_latent
latents_noisy = sigmas * noise + (1.0 - sigmas) * latents            # pred noise

As mentioned in line 335, latents_noisy = sigmas * noise + (1.0 - sigmas) * latents. And we can get:

new_latent = (1.0 - sigmas) * latents + sigmas * (noise - noise_pred)
noise = (1.0 - sigmas) * latents + sigmas * (noise - noise_pred) + noise_pred
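The substitution above can be checked numerically. The following is a minimal sketch, not code from the repository: it replaces the tensors with random scalars, and the helper names irfds_update and closed_form are hypothetical. It only confirms that the quoted assignments and the expanded expressions agree.

```python
import random


def irfds_update(latents, noise, noise_pred, sigma):
    """Mirror the quoted assignments, using scalars instead of tensors."""
    latents_noisy = sigma * noise + (1.0 - sigma) * latents
    new_latent = noise_pred * (-sigma) + latents_noisy
    new_noise = noise_pred + new_latent
    return new_latent, new_noise


def closed_form(latents, noise, noise_pred, sigma):
    """Expanded expressions from the derivation in the question."""
    new_latent = (1.0 - sigma) * latents + sigma * (noise - noise_pred)
    new_noise = (1.0 - sigma) * latents + sigma * (noise - noise_pred) + noise_pred
    return new_latent, new_noise


rng = random.Random(0)
max_gap = 0.0
for _ in range(1000):
    # Arbitrary stand-in values; noise_pred is NOT a real model output here.
    latents = rng.gauss(0.0, 1.0)
    noise = rng.gauss(0.0, 1.0)
    noise_pred = rng.gauss(0.0, 1.0)
    sigma = rng.random()  # sigma in [0, 1)
    step = irfds_update(latents, noise, noise_pred, sigma)
    expanded = closed_form(latents, noise, noise_pred, sigma)
    max_gap = max(max_gap, abs(step[0] - expanded[0]), abs(step[1] - expanded[1]))

print("largest difference between the two forms:", max_gap)
```

The two forms should differ only by floating-point rounding, so the derivation itself holds for any value of noise_pred.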

Specifically, I am confused about the meaning of noise: it seems to be neither Gaussian noise nor a correct noisy latent. Is noise still Gaussian noise? If so, why? If not, why is it interpolated to get the final latents_noisy?

yangxiaofeng commented 2 weeks ago

Hi, thank you for your interest in our work! The noise in Line 335 is Gaussian noise (it corresponds to Line 4 of Algorithm 1 in our paper). The noise in Line 346 is the optimized noise after one iRFDS step (it corresponds to Line 6 of Algorithm 1).

yiboz2001 commented 2 weeks ago

Thanks for the reply!
The noise in my question specifically means the noise in Line 346. Is the optimized noise still Gaussian noise? I cannot tell, because of the (1.0 - sigmas) * latents term in the derivation above. Could you provide more details? Thanks a lot!

yangxiaofeng commented 2 weeks ago

Thanks for the clarification. I think I understand your question now.

Well, it is not guaranteed to be Gaussian noise, but it should be very close to Gaussian. The reason: the "latents" come from a VAE (or from a diffusion model trained to generate VAE outputs). Ideally, the VAE constrains the "latents" to be Gaussian, so the sum of the "latents" and the Gaussian noise should have zero mean. In practice, though, the VAEs of latent diffusion models are usually trained with only a weak constraint, so it is hard to say whether the "latents" are exactly Gaussian, but they should be close.
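A complementary way to look at the updated noise is the rough sketch below. It is not taken from the repository and rests on stated assumptions: the ideal transformer output is taken to be the velocity eps - latents (the unpreconditioned target in the linked diffusers SD3 training script), the latents are drawn from a standard normal as a stand-in for VAE latents, and the prediction error is modeled as independent Gaussian noise of a hypothetical size err_scale. Under those assumptions the algebra gives updated noise = eps + (1 - sigma) * error, so a perfect prediction returns the original Gaussian sample and the deviation shrinks with the prediction error.

```python
import random
import statistics


def updated_noise(latents, eps, noise_pred, sigma):
    """The quoted update, written for scalars."""
    latents_noisy = sigma * eps + (1.0 - sigma) * latents
    new_latent = noise_pred * (-sigma) + latents_noisy
    return noise_pred + new_latent


rng = random.Random(0)
sigma = 0.7
err_scale = 0.1  # hypothetical size of the model's prediction error

gaps_perfect, gaps_noisy, samples = [], [], []
for _ in range(20000):
    latents = rng.gauss(0.0, 1.0)  # assumption: stand-in for a VAE latent
    eps = rng.gauss(0.0, 1.0)      # the Gaussian noise from line 335
    ideal = eps - latents          # assumption: ideal velocity prediction
    err = rng.gauss(0.0, err_scale)  # assumption: independent error

    perfect = updated_noise(latents, eps, ideal, sigma)
    noisy = updated_noise(latents, eps, ideal + err, sigma)

    gaps_perfect.append(abs(perfect - eps))
    gaps_noisy.append(abs(noisy - (eps + (1.0 - sigma) * err)))
    samples.append(noisy)

max_gap_perfect = max(gaps_perfect)  # perfect prediction recovers eps
max_gap_noisy = max(gaps_noisy)      # matches eps + (1 - sigma) * err
sample_mean = statistics.fmean(samples)
sample_std = statistics.pstdev(samples)

print("perfect-prediction gap:", max_gap_perfect)
print("closed-form gap:", max_gap_noisy)
print("mean and std of updated noise:", sample_mean, sample_std)
```

In a real model the prediction error is neither independent of eps nor Gaussian, so this only illustrates why the result can stay close to Gaussian, not that it must.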

yiboz2001 commented 2 weeks ago

Another question: how do lines 345~346 relate to the iRFDS process? I may not fully understand the mechanism of iRFDS, and I would appreciate a more theoretical explanation.

yangxiaofeng commented 2 weeks ago

Hi, to understand iRFDS, you may want to look at the non-SD3 code instead (iRFDS.py & RFDS_Rev.py). The training objective of SD3 is slightly different from that of the original rectified flow paper. If you have any questions after that, feel free to raise them here.