lucidrains / DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch
MIT License

DALL-E Image Embedding #1

Closed: adrian-spataru closed this issue 3 years ago

adrian-spataru commented 3 years ago

A token is any symbol from a discrete vocabulary; for humans, each English letter is a token from a 26-letter alphabet. DALL·E’s vocabulary has tokens for both text and image concepts. Specifically, each image caption is represented using a maximum of 256 BPE-encoded tokens with a vocabulary size of 16384, and the image is represented using 1024 tokens with a vocabulary size of 8192. The images are preprocessed to 256x256 resolution during training. Similar to VQVAE, each image is compressed to a 32x32 grid of discrete latent codes using a discrete VAE that we pretrained using a continuous relaxation. We found that training using the relaxation obviates the need for an explicit codebook, EMA loss, or tricks like dead code revival, and can scale up to large vocabulary sizes.
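For concreteness, the token budget from that description works out as follows (a quick arithmetic sanity check, not code from this repo):

    # Token budget implied by the description above.
    text_tokens = 256              # max BPE caption length, vocab size 16384
    image_tokens = 32 * 32         # 256x256 image -> 32x32 grid of codes, vocab size 8192
    assert image_tokens == 1024
    print(text_tokens + image_tokens)   # 1280 tokens in the combined transformer sequence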

We can use OpenAI's CLIP implementation to filter the good samples, but I would assume they didn't use it to create the embedding. So we could assume they used some kind of VQ-VAE? For example https://github.com/openai/vdvae or https://github.com/NVlabs/NVAE ?

So this repo should have 2-step training:

Step 1 - Pretrain an autoencoder to tokenize the images. We could go small first and do it with a 16x16 embedding grid and a relatively low vocab size (2k-4k?).

Step 2 - Train the decoder transformer. Here we should have a preprocessing step to convert the image-text pairs to tokens: some Hugging Face tokenizer for the text and the encoder of the VQ-VAE for the image.
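A rough sketch of what that Step 2 preprocessing could look like; all names and sizes here are placeholders (a 16x16 grid and a 2048-entry codebook as suggested above), not this repo's actual API:

    import torch

    text_vocab_size = 10000     # e.g. from a Hugging Face BPE tokenizer
    image_vocab_size = 2048     # codebook size of the pretrained discrete VAE

    def to_token_sequence(text_ids, image_codes):
        # text_ids:    (batch, n_text)  BPE token ids from the text tokenizer
        # image_codes: (batch, 16*16)   codebook indices from the VAE encoder
        # Offset the image codes so text and image share one flat vocabulary, then
        # concatenate: the transformer models text tokens followed by image tokens.
        return torch.cat([text_ids, image_codes + text_vocab_size], dim=1)

    # Dummy stand-ins for a real tokenizer / VAE encoder output:
    text_ids = torch.randint(0, text_vocab_size, (1, 64))
    image_codes = torch.randint(0, image_vocab_size, (1, 16 * 16))
    sequence = to_token_sequence(text_ids, image_codes)   # shape (1, 64 + 256)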

We hope that someone will offer pretrained model weights for CLIP to remove bad samples during inference. If it was trained on something like the Microsoft dataset, then it should be general enough for most use cases.

Some Open Questions:

adrian-spataru commented 3 years ago

Ok, we could use https://github.com/CompVis/taming-transformers since it uses the same idea: create a codebook with a VQ-VAE for modelling images with transformers.

lucidrains commented 3 years ago

@adrian-spataru Hi Adrian again! So I don't think they are doing VQVAE anymore, nor anything hierarchical. They simply softly discretize using gumbel-softmax. I'll sketch out the code today
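A minimal sketch of that idea (soft discretization over a codebook via gumbel-softmax); shapes and names are illustrative, not this repo's actual code:

    import torch
    import torch.nn.functional as F

    # Illustrative sizes: 8192 codes, 512-dim code vectors, a 32x32 latent grid.
    num_tokens, codebook_dim = 8192, 512
    codebook = torch.nn.Embedding(num_tokens, codebook_dim)

    # Encoder output: per-position logits over the codebook, shape (b, num_tokens, h, w).
    logits = torch.randn(1, num_tokens, 32, 32)

    # A relaxed (soft) one-hot over the codebook at every spatial position ...
    soft_one_hot = F.gumbel_softmax(logits, tau=1.0, dim=1)

    # ... used to take a weighted mix of codebook vectors that the VAE decoder
    # reconstructs from during pretraining -- no nearest-neighbour lookup needed.
    sampled = torch.einsum('b n h w, n d -> b d h w', soft_one_hot, codebook.weight)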

lucidrains commented 3 years ago

@adrian-spataru Ok, I've put down the main ideas of the paper in code, let me know what you think! I'll get the sampling / ranking code done later today too

ViktorAlm commented 3 years ago

https://colab.research.google.com/drive/1cWcW0rQODZrceNiHGw8WB97DmP3URk4P?usp=sharing

I guess we could use/try OpenAI's CLIP weights? Still struggling with JIT and the different names of the params; I guess I'll figure it out eventually. Would x-transformers be compatible with ViT weights? Where's the qkv in OpenAI's CLIP? Could CLIP also be useful to sort through the dataset before training?

adrian-spataru commented 3 years ago

@ViktorAlm Regarding the qkv, reading your colab, they have a module called MultiheadAttention. Most likely you'll find it in there.

@lucidrains First of all, great job man! That was fast. I think a lot of researchers, PhD students, etc. are going to appreciate the effort you've put into the AI community so far. (I am very grateful for your x-transformers.)

Regarding the code, I see what you mean by "softly" discretize. However, when you pass into the transformer you are just taking the argmax, so isn't the expressiveness kind of lost?

    def get_codebook_indices(self, images):
        # per-position logits over the codebook, shape (b, num_tokens, h, w)
        logits = self.forward(images, return_logits = True)
        # hard choice: index of the most likely code per position, flattened to (b, h*w)
        codebook_indices = logits.argmax(dim = 1).flatten(1)
        return codebook_indices

So instead of using softmax to predict the next token, one should perhaps use sigmoid and treat it as a multi-label classification problem. But that again is not ideal, since it's a weighted mix, i.e. not binary labels. I don't know, am I missing something?
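(For reference, one standard middle ground between a fully soft mix and a plain argmax is the straight-through gumbel-softmax, which PyTorch supports via hard=True: the forward pass is exactly one-hot, while gradients flow through the soft relaxation. Not necessarily what this repo does; just a sketch.)

    import torch
    import torch.nn.functional as F

    logits = torch.randn(1, 8192, 32, 32, requires_grad=True)

    # hard=True: the returned tensor is an exact one-hot in the forward pass,
    # but its backward pass uses the soft relaxation (straight-through estimator).
    one_hot = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=1)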

lucidrains commented 3 years ago

Thanks for the kind words! I believe they are doing a similar setup to Jukebox. The soft discretization only happens during pretraining of the VAE. You then encode the images of your training set as indices into the learned codebook, and those indices are what get trained with the decoder in DALL-E. I don't believe it is trained end to end, but I could be wrong
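Concretely, that two-stage reading could look something like the sketch below, with a stand-in for the pretrained VAE encoder (nothing here is this repo's actual class):

    import torch

    # Stand-in for a pretrained discrete VAE encoder: a 1x1 conv producing
    # per-position logits over an 8192-entry codebook (illustrative only).
    vae_encoder = torch.nn.Conv2d(3, 8192, kernel_size=1)

    @torch.no_grad()
    def encode_to_indices(images):
        logits = vae_encoder(images)               # (b, 8192, h, w)
        return logits.argmax(dim=1).flatten(1)     # (b, h*w) hard codebook indices

    images = torch.randn(2, 3, 32, 32)
    image_tokens = encode_to_indices(images)
    # These hard indices are the targets the DALL-E transformer is trained on,
    # with the VAE itself frozen, i.e. not trained end to end with the transformer.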

Hmm, as for your last point, I don't think it is multi-label classification

adrian-spataru commented 3 years ago

@lucidrains I see, however the difference is that in Jukebox you build a separate model for generating the music. Yes, we are doing the same here with the transformer, but the big difference is that we are still using the original VQ-VAE decoder to reconstruct the image we generated. The issue is that the VQ-VAE decoder is then fed binary values and not the "softly" discrete values (we could say they are bounded continuous values). My fear is that the image reconstruction will be poor because of this.

That said, this is just my assumption; if the thing works like this, all my comments can be thrown away.
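For reference, a minimal sketch of inference-time decoding under that reading: the generated indices are turned back into full codebook embedding vectors before they reach the VAE decoder, so the decoder input is continuous per position rather than literally binary (names and sizes are illustrative):

    import torch

    num_tokens, codebook_dim = 8192, 512
    codebook = torch.nn.Embedding(num_tokens, codebook_dim)

    # Indices produced by the transformer at inference time (random here),
    # one per cell of the 32x32 latent grid.
    generated = torch.randint(0, num_tokens, (1, 32 * 32))

    # Hard index -> codebook embedding per cell -> spatial feature map for the
    # VAE decoder. Whether reconstruction suffers from never seeing the soft
    # mixtures used during pretraining is exactly the open question above.
    embeddings = codebook(generated)                                  # (1, 1024, 512)
    decoder_input = embeddings.permute(0, 2, 1).reshape(1, codebook_dim, 32, 32)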

diff7 commented 3 years ago

I am also curious about soft discretization during inference

lucidrains commented 3 years ago

@adrian-spataru if you think about the iGPT paper, they hard-clustered into 512 discrete tokens, and it still worked in the end through the sheer force of attention. My judgement may be clouded by my undue faith, however :)

If we cannot get it to work, I'm willing to build whatever you or someone else proposes

Edit: Is this close to what you meant? https://github.com/lucidrains/DALLE-pytorch/commit/0c75b48ac10d3955867ca61366d26ad6ff81e613 (end to end training of vae and decoder)

Edit 2: https://github.com/lucidrains/DALLE-pytorch/tree/end-to-end Released it as package dalle-pytorch-dev. I asked some researcher friends, and they are not sure if it will work. Probably should just try it

lucidrains commented 3 years ago
[Screenshot attached: "pretrained"]

CDitzel commented 3 years ago

I've been following this discussion for a while now and I am also wondering about the codebook part.

The original VQ-VAE measures the vectorial distance of every latent vector to every codebook vector (the rows of the codebook) and keeps the closest one for quantization purposes. The selected index might change, however, if during training the latent vectors change and are then closer to another codebook row. So far so good.

In the code of this repo, however, after the gumbel softmax there is a contraction along the feature dimension, combining the latent vector's feature dimension with the individual columns of the codebook in a multiplicative manner. Does this mean that the dot product is effectively evaluated between latent vectors and separate columns of the codebook, instead of comparing them with the codebook's feature dimension (i.e. along the rows)? Seems odd to me.
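To make the comparison concrete, here is a side-by-side sketch of the two quantization schemes being discussed (illustrative shapes, not code lifted from either repo):

    import torch
    import torch.nn.functional as F

    num_codes, dim = 1024, 256
    codebook = torch.randn(num_codes, dim)     # one code vector per row
    latents = torch.randn(4, dim)              # a few encoder latent vectors

    # Classic VQ-VAE: nearest codebook *row* per latent vector.
    dists = torch.cdist(latents, codebook)               # (4, num_codes)
    quantized_vq = codebook[dists.argmin(dim=1)]          # (4, dim)

    # Gumbel-softmax variant discussed here: the encoder outputs logits over the
    # codes, and the (soft) one-hot is contracted with the codebook over the code
    # index n -- i.e. a weighted mix of whole rows, rather than a comparison
    # along the feature dimension d.
    logits = torch.randn(4, num_codes)
    soft_one_hot = F.gumbel_softmax(logits, tau=1.0, dim=-1)
    quantized_gumbel = torch.einsum('b n, n d -> b d', soft_one_hot, codebook)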

adrian-spataru commented 3 years ago

Ok, paper is out. No need to guess: https://arxiv.org/pdf/2102.12092.pdf