lucidrains / imagen-pytorch

Implementation of Imagen, Google's Text-to-Image Neural Network, in Pytorch
MIT License

Benefit of CrossEmbedLayer compared to just a single Conv2D #263

Closed Mut1nyJD closed 2 years ago

Mut1nyJD commented 2 years ago

Hi @lucidrains

I started to compare the UNet architectures between the one in Imagen and the one in denoising-diffusion. A lot of the differences are due to the text embedding, so I ignored most of that. The major differences I could see were:

  1. CrossEmbed is used as the initial_conv and as the downsampling conv in the down ResNet blocks
  2. There is a pre-downsample branch for memory efficiency in the ResNet structure
  3. A few other bits like fmap aggregation for upsampling, etc.
  4. The upsampler uses PixelShuffle; not really sure if that adds much, as I have never experienced the checkerboard artefact so far
  5. SiLU is used instead of GELU in the time embedding

I understand the reasons for 2, 3 and 4, but I don't really understand the benefit of 1.

Could you maybe explain the reasoning for this? And what is the benefit of it compared to just a single 2D convolution layer? Thank you!

lucidrains commented 2 years ago

@Mut1nyJD 👋 so that comes from this paper: https://arxiv.org/abs/2108.00154, but you are right, the improvements could be negligible
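For context, the cross-scale embedding idea from that paper is roughly: run several convolutions with different kernel sizes (but the same stride) in parallel and concatenate their outputs along the channel dimension, so the layer sees several receptive-field scales at once instead of one. A minimal sketch of the idea (the kernel sizes and channel split below are illustrative, not necessarily the exact values this repo uses):

```python
import torch
from torch import nn

class MultiKernelConvStem(nn.Module):
    """Parallel convs with different kernel sizes, outputs concatenated along
    channels -- a sketch of the cross-scale embedding idea from
    https://arxiv.org/abs/2108.00154 (CrossFormer)."""

    def __init__(self, dim_in, dim_out, kernel_sizes=(3, 7, 15), stride=1):
        super().__init__()
        kernel_sizes = sorted(kernel_sizes)
        # give the smaller kernels more channels, the largest kernel the remainder
        dim_scales = [dim_out // (2 ** i) for i in range(1, len(kernel_sizes))]
        dim_scales.append(dim_out - sum(dim_scales))

        self.convs = nn.ModuleList([
            nn.Conv2d(dim_in, dim_scale, kernel, stride=stride,
                      padding=(kernel - stride) // 2)
            for kernel, dim_scale in zip(kernel_sizes, dim_scales)
        ])

    def forward(self, x):
        return torch.cat([conv(x) for conv in self.convs], dim=1)

# all branches produce the same spatial size, so concatenation works
stem = MultiKernelConvStem(3, 128)
print(stem(torch.randn(1, 3, 64, 64)).shape)  # torch.Size([1, 128, 64, 64])
```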

the biggest difference in flexibility between the unets here and the ddpm-pytorch repo is the option to use full attention at any stage in the unet, as well as to customize the depth. other small details include a scale on the skip connections, which some papers report gives better convergence when doing super-resolution
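That skip-connection scale amounts to multiplying each encoder skip by a constant before it is merged back in on the up path; a rough sketch, where the 1/sqrt(2) constant is an assumption based on the papers alluded to, not a quote of this repo's code:

```python
import torch

SKIP_SCALE = 2 ** -0.5  # assumed constant; some papers scale skips by 1/sqrt(2)

def combine_with_skip(x, skip):
    """Concatenate a (scaled) encoder skip with the decoder feature map."""
    return torch.cat((x, skip * SKIP_SCALE), dim=1)

x = torch.randn(1, 64, 32, 32)     # decoder features
skip = torch.randn(1, 64, 32, 32)  # matching encoder features
print(combine_with_skip(x, skip).shape)  # torch.Size([1, 128, 32, 32])
```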

lucidrains commented 2 years ago

also have the option for a better type of squeeze excitation here within the resnet blocks (global context attention)
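Global context attention is essentially squeeze-excitation with attention-weighted pooling instead of global average pooling: a 1x1 conv produces per-pixel weights, the softmax-weighted sum of the features becomes the squeeze vector, and a small bottleneck produces per-channel gates. A rough sketch of that pattern (layer sizes below are illustrative, not the repo's exact module):

```python
import torch
from torch import nn

class GlobalContextGate(nn.Module):
    """Squeeze-excitation with attention pooling instead of average pooling,
    in the spirit of global context attention (GCNet). Sketch only."""

    def __init__(self, dim, reduction=2):
        super().__init__()
        self.to_attn = nn.Conv2d(dim, 1, 1)  # per-pixel pooling weights
        hidden = max(dim // reduction, 1)
        self.net = nn.Sequential(
            nn.Conv2d(dim, hidden, 1),
            nn.SiLU(),
            nn.Conv2d(hidden, dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, h, w = x.shape
        attn = self.to_attn(x).view(b, 1, h * w).softmax(dim=-1)     # (b, 1, hw)
        feats = x.view(b, c, h * w)                                  # (b, c, hw)
        pooled = torch.einsum('b i n, b c n -> b c i', attn, feats)  # (b, c, 1)
        gate = self.net(pooled.unsqueeze(-1))                        # (b, c, 1, 1)
        return x * gate

x = torch.randn(1, 64, 32, 32)
print(GlobalContextGate(64)(x).shape)  # torch.Size([1, 64, 32, 32])
```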

Mut1nyJD commented 2 years ago

> @Mut1nyJD 👋 so that comes from this paper: https://arxiv.org/abs/2108.00154, but you are right, the improvements could be negligible
>
> the biggest difference in flexibility between the unets here and the ddpm-pytorch repo is the option to use full attention at any stage in the unet, as well as to customize the depth. other small details include a scale on the skip connections, which some papers report gives better convergence when doing super-resolution

Ah ok, thank you for the paper reference. Well, I am doing some ablation runs; I've added it to my ddpm checkout, and as soon as I have something I'll be happy to share it.

You mean full attention instead of just linear attention? Also in the up blocks? Hmm, I must have missed that.

Thank you for your detailed answer!

lucidrains commented 2 years ago

@Mut1nyJD yea no problem

yet another difference is https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py#L1037 , taken from https://arxiv.org/abs/2005.09007

it takes all the intermediate feature representations, resizes them, and then incorporates them into the final prediction
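In other words, the intermediate feature maps from the upsampling path are resized to the output resolution, projected, and concatenated with the final feature map before the last prediction layer. A hedged sketch of that pattern (module and parameter names below are mine, not the repo's):

```python
import torch
from torch import nn
import torch.nn.functional as F

class UpsampleFmapCombiner(nn.Module):
    """Resize intermediate decoder feature maps to the final resolution,
    project each with a conv, and concatenate them with the final feature map.
    Sketch of the idea referenced above, not the repo's exact module."""

    def __init__(self, fmap_dims, dim_out_per_fmap=32):
        super().__init__()
        self.projs = nn.ModuleList([
            nn.Conv2d(dim, dim_out_per_fmap, 3, padding=1) for dim in fmap_dims
        ])

    def forward(self, final_fmap, intermediate_fmaps):
        target_size = final_fmap.shape[-2:]
        resized = [
            proj(F.interpolate(fmap, size=target_size, mode='nearest'))
            for proj, fmap in zip(self.projs, intermediate_fmaps)
        ]
        return torch.cat([final_fmap, *resized], dim=1)

combiner = UpsampleFmapCombiner(fmap_dims=(256, 128))
final = torch.randn(1, 64, 64, 64)
inters = [torch.randn(1, 256, 16, 16), torch.randn(1, 128, 32, 32)]
print(combiner(final, inters).shape)  # torch.Size([1, 128, 64, 64])
```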

Mut1nyJD commented 2 years ago

@lucidrains

So, some initial results from my testing; I will provide more thorough numbers and details in due course. The current finding, on at least one dataset, when looking at quantitative metrics (FID computed with pytorch-fid over 1500 samples): there is a gain (lower FID scores) when using CrossEmbedLayer instead of a large-kernel Conv2d as the initial layer, and a further gain (again lower FID scores) when also using CrossEmbedLayer in the downsampling branch of the ResNet blocks. The benefit of the latter seems to manifest earlier (already at lower iteration counts), while the benefit of using it only in the initial layer seems to take longer to show.
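For anyone wanting to reproduce this kind of comparison, the FID numbers can be computed with pytorch-fid against a folder of real images and a folder of generated samples; a minimal example (the paths are placeholders, and the exact function signature should be checked against the installed pytorch-fid version):

```python
# Hypothetical evaluation script: compares a folder of real images against a
# folder of generated samples with pytorch-fid (paths are placeholders).
# Equivalent CLI: python -m pytorch_fid data/real data/generated
import torch
from pytorch_fid.fid_score import calculate_fid_given_paths

fid = calculate_fid_given_paths(
    paths=['data/real', 'data/generated'],  # e.g. ~1500 samples each, as above
    batch_size=50,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    dims=2048,  # standard InceptionV3 pool features
)
print(f'FID: {fid:.2f}')
```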