lucidrains / imagen-pytorch

Implementation of Imagen, Google's Text-to-Image Neural Network, in Pytorch
MIT License

Implementing `Conditional Augmentation` #44

Closed CiaoHe closed 2 years ago

CiaoHe commented 2 years ago

Hi Phil, I noticed that Imagen applies noise conditioning augmentation to both super-res models.

> we corrupt the low-resolution image with the augmentation (corresponding to aug_level), and condition the diffusion model on aug_level. During training, aug_level is chosen randomly, while during inference, we sweep over its different values to find the best sample quality. In our case, we use Gaussian noise as a form of augmentation, and apply variance preserving Gaussian noise augmentation resembling the forward process used in diffusion models (Appendix A). The augmentation level is specified using aug_level ∈ [0, 1].

Based on my understanding, in the training phase they first apply forward diffusion (at the aug_level noise scale) to the low-res image x_lr, and then feed x_lr, z_t (the noised high-res latent), t, and aug_level to the Unet and optimize the loss. In the sampling phase they do the same thing, like the following pseudo-code:

[image: pseudo-code for training and sampling with noise conditioning augmentation]
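To make the idea concrete, here is a minimal sketch of that procedure (the cosine schedule choice and all function/argument names are my own, purely for illustration):

```python
import torch

def noise_augment(x_lr, aug_level):
    # variance-preserving Gaussian augmentation mirroring the forward
    # diffusion process: x_aug = sqrt(alpha_bar) * x + sqrt(1 - alpha_bar) * eps
    # aug_level in [0, 1] is mapped to alpha_bar with a cosine schedule
    # (the exact schedule is an assumption for this sketch)
    alpha_bar = torch.cos(aug_level * torch.pi / 2) ** 2
    eps = torch.randn_like(x_lr)
    return alpha_bar.sqrt() * x_lr + (1. - alpha_bar).sqrt() * eps

x_lr = torch.randn(4, 3, 64, 64)          # dummy low-res batch

# training: sample a random augmentation level per example,
# then condition the unet on both t and aug_level
aug_level = torch.rand(4, 1, 1, 1)
x_lr_noised = noise_augment(x_lr, aug_level)

# sampling: sweep a few fixed levels and keep the best-quality samples
for level in (0.0, 0.1, 0.3):
    lvl = torch.full((4, 1, 1, 1), level)
    x_lr_noised = noise_augment(x_lr, lvl)
```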

My doubt is: how should aug_level be fed into the Unet? The idea of conditioning on the augmentation noise was first proposed in Cascaded Diffusion Models for High Fidelity Image Generation (CDM) (https://arxiv.org/abs/2106.15282), where they suggest using another time embedding for the aug level (referred to as time s).

[image: excerpt from the CDM paper describing the extra time embedding for the augmentation level s]
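In code, CDM's suggestion amounts to a second time embedding for s alongside the usual one for t. A rough sketch (the module and function names here are mine, not from the repo):

```python
import math
import torch
from torch import nn

def sinusoidal_embedding(t, dim):
    # standard transformer-style sinusoidal embedding of a scalar per example
    half = dim // 2
    freqs = torch.exp(-math.log(10000) * torch.arange(half, dtype = torch.float) / half)
    args = t[:, None].float() * freqs[None]
    return torch.cat((args.sin(), args.cos()), dim = -1)

class TwoTimeConditioning(nn.Module):
    # one MLP for the diffusion time t, another for the augmentation level s,
    # as suggested by CDM; the two embeddings are summed into one conditioning vector
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.to_t = nn.Sequential(nn.Linear(dim, dim * 4), nn.SiLU(), nn.Linear(dim * 4, dim * 4))
        self.to_s = nn.Sequential(nn.Linear(dim, dim * 4), nn.SiLU(), nn.Linear(dim * 4, dim * 4))

    def forward(self, t, s):
        emb_t = self.to_t(sinusoidal_embedding(t, self.dim))
        emb_s = self.to_s(sinusoidal_embedding(s, self.dim))
        return emb_t + emb_s

cond = TwoTimeConditioning(dim = 32)
emb = cond(torch.rand(4), torch.rand(4))   # diffusion time t, aug level s
```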

Few recent papers besides Imagen mention this trick. I just wonder what you think about it, and whether it is necessary to add it to the current SuperResUnet model?

Best,

lucidrains commented 2 years ago

@CiaoHe Hi He Cao, and thank you for reviewing my code again 😃

So I am not 💯 sure on this, but I went with the exact same conditioning as done in DDPM. For the discrete case, I condition on the time step, while in the continuous case, on the log(snr) value (signal to noise ratio, so it correlates with the noise level)
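For the continuous case, that conditioning signal looks roughly like this (a minimal sketch assuming a cosine noise schedule; the function name is mine):

```python
import torch

def cosine_log_snr(t, eps = 1e-5):
    # for a cosine schedule, alpha_bar(t) = cos(t * pi / 2) ** 2, so
    # log SNR(t) = log(alpha_bar / (1 - alpha_bar)) = -2 * log(tan(t * pi / 2))
    t = t.clamp(eps, 1. - eps)   # avoid +/- infinity at the endpoints
    return -2. * torch.log(torch.tan(t * torch.pi / 2))

t = torch.rand(4)            # continuous time in [0, 1]
log_snr = cosine_log_snr(t)  # high at t ≈ 0 (little noise), low at t ≈ 1
```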

Relevant lines for sampling are https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py#L1472

And the lines for training are https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py#L1579 and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py#L1517

The Unet was also designed to accept two conditioning signals, one from the DDPM, and one for the augmentation noise level https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py#L1080
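So a super-res unet call ends up carrying both signals, roughly like this (a sketch only; the exact parameter names in the repo may differ slightly):

```python
# z_t: noised high-res image, t: its diffusion time conditioning
# x_lr_noised: augmented low-res image, aug_times: its augmentation level
pred = unet(
    z_t,
    t,
    lowres_cond_img = x_lr_noised,     # low-res conditioning image
    lowres_noise_times = aug_times     # augmentation noise level conditioning
)
```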

Do let me know what you think! I have a feeling I may need to approach the cascading DDPM in the dalle2-pytorch repository the same way (I am not conditioning the super-resolution unets there, just blurring the output of the previous unet)

CiaoHe commented 2 years ago

Wow, you have already considered this, great! I think this part is exactly what CDM did: https://github.com/lucidrains/imagen-pytorch/blob/d4e12c0781385f42cbbf51128e5bc3d73abc0038/imagen_pytorch/imagen_pytorch.py#L1080-L1084

Imagen adopts this aug-conditioning approach, but is this conditioning trick really better than a simple blur? Haha, I think it's worth comparing. Anyway, thanks for your reply, I think you're on the right track!

lucidrains commented 2 years ago

@CiaoHe yup, i think Imagen had the better super-resolution unets, so it is worth porting over this conditioning trick to dalle2-pytorch once we see some evidence from the open source community that Imagen-pytorch is working :smiley: