adrian-spataru closed this issue 3 years ago
Ok, we could use https://github.com/CompVis/taming-transformers since it uses the same idea: create a codebook with a VQ-VAE for modelling images in Transformers.
@adrian-spataru Hi Adrian again! So I don't think they are doing VQVAE anymore, nor anything hierarchical. They simply softly discretize using gumbel-softmax. I'll sketch out the code today
@adrian-spataru Ok, I've put down the main ideas of the paper in code, let me know what you think! I'll get the sampling / ranking code done later today too
https://colab.research.google.com/drive/1cWcW0rQODZrceNiHGw8WB97DmP3URk4P?usp=sharing
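To make the "softly discretize using gumbel-softmax" idea concrete, here is a minimal sketch with made-up shapes and hyperparameters (512 codebook entries, 256-dim codes, a 16x16 grid); it is not the actual repo code, just an illustration of the mechanism being discussed:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: logits over `num_tokens` codebook entries per
# spatial position, plus a learnable codebook of vectors.
num_tokens, codebook_dim = 512, 256
codebook = torch.nn.Embedding(num_tokens, codebook_dim)

logits = torch.randn(2, num_tokens, 16, 16)  # (batch, num_tokens, h, w)

# Soft discretization: gumbel-softmax yields an (almost) one-hot
# distribution over codebook entries that still carries gradients.
soft_one_hot = F.gumbel_softmax(logits, tau = 1., dim = 1)

# Each spatial position becomes a weighted mix of codebook vectors:
# (b, n, h, w) x (n, d) -> (b, d, h, w)
quantized = torch.einsum('b n h w, n d -> b d h w', soft_one_hot, codebook.weight)
```

The decoder is then trained on `quantized`, so the whole VAE stays differentiable end to end.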
I guess we could use/try OpenAI's CLIP weights? Still struggling with jit and the different names of the params; I guess I'll figure it out eventually. Would x-transformers be compatible with ViT weights? Where's the qkv in OpenAI CLIP? Could CLIP even be useful for sorting through the dataset before training?
@ViktorAlm Regarding the qkv, reading your colab, they have a module called MultiheadAttention. Most likely you'll find it in there.
@lucidrains First of all, great job man! That was fast. I think a lot of researchers, PhD students, etc. are going to appreciate the effort you've put into the AI community so far. (I am very grateful for your x-transformers.)
Regarding the code, I see what you mean by "softly" discretize. However, when you pass into the transformer, you are just taking the argmax, so the expressiveness is kind of lost?
def get_codebook_indices(self, images):
    # logits: (batch, num_tokens, height, width)
    logits = self.forward(images, return_logits = True)
    # hard argmax over the codebook dimension, flattened to (batch, height * width)
    codebook_indices = logits.argmax(dim = 1).flatten(1)
    return codebook_indices
So instead of using softmax to predict the next token, one should use sigmoid and treat it as a multi-label classification problem. But that again is not ideal, since it's a weighted mix, i.e. not binary labels. I don't know, am I missing something?
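The concern above can be illustrated with a small contrast between the soft mix and the hard argmax lookup (made-up shapes, not repo code):

```python
import torch

# Hypothetical codebook and per-position logits.
num_tokens, codebook_dim = 512, 256
codebook = torch.randn(num_tokens, codebook_dim)
logits = torch.randn(2, num_tokens, 16, 16)

# Soft path: each position is a weighted mix of codebook vectors.
soft = torch.einsum('b n h w, n d -> b d h w', logits.softmax(dim = 1), codebook)

# Hard path (what argmax-based indexing yields): exactly one codebook
# vector per position; the expressiveness of the mix is dropped.
indices = logits.argmax(dim = 1)              # (b, h, w)
hard = codebook[indices].permute(0, 3, 1, 2)  # (b, d, h, w)

# The two generally differ, which is the worry being raised.
gap = (soft - hard).abs().mean()
```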
Thanks for the kind words! I believe they are doing a similar setup as Jukebox. The soft discretization only happens during pre-training of the VAE. You then encode the images of your training set as the indices of the learned codebook, and that goes under training with the decoder in DALL-E. I don't believe it is trained end to end, but I could be wrong
Hmm, as for your last point, I don't think it is multi-label classification
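A rough sketch of the two-stage pipeline described above (frozen pretrained VAE producing index sequences that the transformer then models autoregressively); the class interfaces here are assumptions for illustration, not the repo's actual API:

```python
import torch

@torch.no_grad()
def images_to_tokens(vae, images):
    # Assumes the pretrained VAE exposes get_codebook_indices,
    # returning (batch, seq_len) integer indices per image.
    return vae.get_codebook_indices(images)

def transformer_loss(transformer, text_tokens, image_tokens):
    # Concatenate text and image tokens and train next-token prediction.
    tokens = torch.cat((text_tokens, image_tokens), dim = 1)
    logits = transformer(tokens[:, :-1])  # (batch, len - 1, vocab)
    return torch.nn.functional.cross_entropy(
        logits.reshape(-1, logits.shape[-1]),
        tokens[:, 1:].reshape(-1)
    )
```

The VAE is only a tokenizer at this stage; gradients flow through the transformer alone.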
@lucidrains I see, however the difference is that in Jukebox you build a separate model for generating the music. Yes, we are doing the same here with the Transformer, but the big difference is that we are still using the original VQ-VAE decoder to reconstruct the image we generated. The issue is that the VQ-VAE decoder is being fed binary values rather than the "softly" discrete values (we could say they are bounded continuous values). My fear is that the image reconstruction will be poor because of this.
That said, this is just my assumption; if the thing works like this, all my comments can be thrown away.
I am also curious about soft discretization during inference
@adrian-spataru if you think on the iGPT paper, they hard clustered into 512 discrete tokens, and it still worked in the end with the sheer force of attention. My judgement may be clouded by my undue faith, however :)
If we cannot get it to work, I'm willing to build whatever you or someone else propose
Edit: Is this close to what you meant? https://github.com/lucidrains/DALLE-pytorch/commit/0c75b48ac10d3955867ca61366d26ad6ff81e613 (end to end training of vae and decoder)
Edit 2: https://github.com/lucidrains/DALLE-pytorch/tree/end-to-end Released it as package dalle-pytorch-dev
I asked some researcher friends, and they are not sure if it will work. Probably should just try it
I've been following this discussion for a while now and I am also wondering about the codebook part.
The original VQVAE measures vectorial distance of every latent vector to every codebook vector (rows of the codebook) and keeps the closest one for quantization purposes. The selected index might change however, if during training, the latent vectors change and are then closer to another codebook row. So far so good.
In the code of this repo however, after gumbel-softmax, there is a contraction along the feature dimension, combining the latent vectors' feature dimension with the individual columns of the codebook in a multiplicative manner. Does this mean that the dot product is effectively evaluated between latent vectors and separate columns of the codebook, instead of comparing them with the codebook's feature dimensions (i.e. along the rows)? Seems odd to me
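For reference, the original VQ-VAE quantization described above can be sketched like this (illustrative shapes; `codebook` rows are the code vectors):

```python
import torch

# Codebook: each row is one code vector.
num_tokens, codebook_dim = 512, 64
codebook = torch.randn(num_tokens, codebook_dim)
latents = torch.randn(2, 16 * 16, codebook_dim)  # (batch, positions, d)

# Squared euclidean distance from every latent vector to every
# codebook row: ||z||^2 - 2 z.e + ||e||^2  ->  (batch, positions, num_tokens)
dists = (latents.pow(2).sum(-1, keepdim = True)
         - 2 * latents @ codebook.t()
         + codebook.pow(2).sum(-1))

# Keep the index of the nearest row; that index can change during
# training as the latents (and codebook) move.
indices = dists.argmin(dim = -1)   # (batch, positions)
quantized = codebook[indices]      # nearest code vector per position
```

The gumbel-softmax variant replaces this hard nearest-row lookup with a differentiable soft assignment over the same rows.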
Ok, paper is out. No need to guess: https://arxiv.org/pdf/2102.12092.pdf
We can use the OpenAI CLIP implementation to filter the good samples, but I would assume they didn't use it to create the embedding. Therefore we could assume they used some kind of VQ-VAE? For example https://github.com/openai/vdvae or https://github.com/NVlabs/NVAE ?
So this repo should have a 2-step training:
Step 1 - Pretrain an autoencoder to tokenize the images. We could go small first and do it with a 16x16 embedding and a relatively low vocab size (2k-4k?).
Step 2 - Train the decoder Transformer. Here we should have a preprocessing step to convert the image-text pairs to tokens: some Huggingface tokenizer for the text and the encoder of the VQ-VAE for the images.
We hope that someone will offer pretrained model weights for CLIP to remove bad samples during inference. If it was trained on something like the Microsoft dataset, then it should be general enough for most use cases.
Some Open Questions: