lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
MIT License

Base parameters for replication of decoder #137

Open Veldrovive opened 2 years ago

Veldrovive commented 2 years ago

This run has verified distributed training for the first unet of the decoder and I would like to set out some standard parameters that can be experimented with going forward.

The above run uses embeddings from vit-l-14 and the following unet parameters:

dim = 512
dim_mults = (1, 2, 3, 4)
attn_dim_head = 32
attn_heads = 16
resnet_groups = 8
num_resnet_blocks = 2
init_cross_embed_kernel_sizes = (3, 7, 15)
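For anyone wanting to reproduce this, the hyperparameters above can be collected into a single config dict that could then be unpacked into the decoder's first unet (e.g. Unet(**unet_config)). This is just a sketch of how the values fit together, not verified against any particular version of the library:

```python
# Hyperparameters from the run above, gathered into one place so they can be
# unpacked into a Unet constructor or serialized into a training config.
# The keyword names are taken verbatim from the run description.
unet_config = dict(
    dim=512,                                   # base channel dimension
    dim_mults=(1, 2, 3, 4),                    # channel multipliers per resolution stage
    attn_dim_head=32,                          # dimension per attention head
    attn_heads=16,                             # number of attention heads
    resnet_groups=8,                           # groups for group normalization in resnet blocks
    num_resnet_blocks=2,                       # resnet blocks per stage
    init_cross_embed_kernel_sizes=(3, 7, 15),  # kernel sizes for the initial cross-embed conv
)

# Sanity check: the widest stage should be dim * max(dim_mults)
widest = unet_config["dim"] * max(unet_config["dim_mults"])
print(widest)
```

Keeping the parameters in a dict like this also makes it easy to diff runs against each other when experimenting with changes.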

I have no reference point for what could be changed to improve performance. If anyone has insight into which hyperparameter values might help, it would go a long way toward making the guesswork quicker.

lucidrains commented 2 years ago

@Veldrovive it is safe to just go with the same hyperparameters as Imagen, since Imagen outperforms DALL-E 2 anyway. we know at the very least that scaling the unets is unnecessary