lucidrains / DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch
MIT License

General questions to the algorithmic understanding #21

Open CDitzel opened 3 years ago

CDitzel commented 3 years ago

Been trying to get a grasp of the DALL-E code recently. However, there are a couple of things I can't quite wrap my head around, and since the paper is not published yet, I was wondering if we could maybe clarify them here.

So there is the VAE training which basically features the codebook in the bottleneck and is trained a priori.

Next, DALL-E receives text and image pairs, embeds them, and adds positional encodings to each modality individually. However, the image data is not embedded as in e.g. ViT, but by downsampling it via the encoder of the VAE (without accumulating gradients), taking the argmax across the feature dimension for each downsampled image patch, and finally indexing into the previously trained codebook.

The resulting representations of both modalities are then concatenated along the token dimension. While every word of the text input is one token, the number of image tokens is given by the height and width of the VAE-encoded image.

The combined embedding is then passed into a single transformer which computes self-attention not only intra-modally but also across both modalities, if I am not mistaken.
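
To make sure I read the code right, here is a minimal, self-contained sketch of how I picture the input construction (illustrative names and toy sizes, not the exact repo code):

import torch
import torch.nn as nn

# toy sizes matching the config further below
b, text_seq_len, num_text_tokens = 1, 3, 4
num_image_tokens, image_seq_len, dim = 10, 4, 256

text = torch.randint(0, num_text_tokens, (b, text_seq_len))
image_feat = torch.randn(b, num_image_tokens, 2, 2)        # stand-in for the frozen VAE encoder output

with torch.no_grad():                                      # no gradients through the VAE encoder
    image_codes = image_feat.argmax(dim=1).flatten(1)      # (b, image_seq_len) codebook indices

text_emb, image_emb = nn.Embedding(num_text_tokens, dim), nn.Embedding(num_image_tokens, dim)
text_pos, image_pos = nn.Embedding(text_seq_len, dim), nn.Embedding(image_seq_len, dim)

tokens = torch.cat((
    text_emb(text)         + text_pos(torch.arange(text_seq_len)),
    image_emb(image_codes) + image_pos(torch.arange(image_seq_len)),
), dim=1)                                                  # (b, 7, dim), fed into the transformer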

A masking of the form

mask = torch.ones_like(text).bool()

results in unmasked attention calculation, right?

A final MLP maps the transformer output onto all possible tokens (both text and image).

Then I don't understand the masking:

     logits_mask = (
            ((seq_range >= (text_seq_len - 1)) & (logits_range < num_text_tokens)) |
            ((seq_range < (text_seq_len - 1)) & (logits_range >= num_text_tokens)) |
            ((seq_range != (seq_len - 1)) & (logits_range >= (total_tokens - 1)))
        )

shouldn't there be one more row concerned with the text input and one row less for the image input?

For the following config with 3 text input tokens

vae = DiscreteVAE(
    image_size = 64,
    num_layers = 5,
    num_tokens = 10,
    codebook_dim = 256,
    hidden_dim = 64,
    num_resnet_blocks = 1,
    temperature = 0.9
).cuda()

dalle = DALLE(
    dim = 256,
    vae = vae,                  # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens = 4,    # vocab size for text
    text_seq_len = 3,         # text sequence length
    depth = 6,                 # should aim to be 64
    heads = 8,                 # attention heads
    dim_head = 64,              # attention head dimension
    attn_dropout = 0.1,         # attention dropout
    ff_dropout = 0.1            # feedforward dropout
).cuda()

text = torch.randint(0, 4, (1, 3)).cuda()
images = torch.randn(1, 3, 64, 64).cuda()
mask = torch.ones_like(text).bool().cuda()

the mask looks like this

tensor([[[False, False, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [False, False, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False]]])
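
For reference, this printed mask can be reproduced standalone from the config above (a sketch using plain broadcasting instead of the rearranges in the repo):

import torch

text_seq_len, num_text_tokens = 3, 4
image_seq_len, num_image_tokens = 4, 10                  # 64px image, 5 layers -> 2x2 = 4 image tokens
seq_len = text_seq_len + image_seq_len                   # 7
total_tokens = num_text_tokens + num_image_tokens + 1    # + 1 for the EOS logit -> 15

seq_range    = torch.arange(seq_len)[:, None]            # (7, 1)
logits_range = torch.arange(total_tokens)[None, :]       # (1, 15)

logits_mask = (
    ((seq_range >= (text_seq_len - 1)) & (logits_range < num_text_tokens)) |
    ((seq_range <  (text_seq_len - 1)) & (logits_range >= num_text_tokens)) |
    ((seq_range != (seq_len - 1))      & (logits_range >= (total_tokens - 1)))
)
print(logits_mask)                                       # reproduces the 7 x 15 tensor above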

shouldn't it be?

tensor([[[False, False, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [False, False, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [False, False, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False]]])

The purpose of the masking is that image tokens don't contribute to the predictions of text tokens and vice versa. The code then proceeds by constructing labels from the integer text tokens and from the codebook indices of the VAE-encoded image.
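
If I read the code correctly, the label construction is roughly the following (a sketch; the exact offsets and EOS handling may differ):

import torch
import torch.nn.functional as F

num_text_tokens, num_image_tokens = 4, 10
text = torch.randint(0, num_text_tokens, (1, 3))                    # integer text tokens
image_codes = torch.randint(0, num_image_tokens, (1, 4))            # codebook indices from the VAE encoder

labels = torch.cat((text, image_codes + num_text_tokens), dim=1)    # image ids live after the text vocab
eos_id = num_text_tokens + num_image_tokens                         # last id (14) reserved for EOS
labels = F.pad(labels, (0, 1), value=eos_id)                        # append EOS -> length 8
# the loss then compares the logits at positions [:-1] with labels[:, 1:]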

But what is it we are actually trying to predict with this classification task? It is a 2D cross-entropy loss where, for each token (either text or image), we are trying to predict ... exactly what? Somehow I am missing the intuition here, I guess.

And then, why does the label vector skip the very first label entry but use the EOS entry?

    loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels[:, 1:])

Maybe someone can help me (and others) better understand what's going on here. Thank you in advance!

lucidrains commented 3 years ago

@CDitzel Hey Carsten! Glad you are reading the repository and double checking my work!

So I think the best way to make this clear is to work through a concrete example. Let's say each text token is just a character, and let's pretend numbers are visual tokens.

say we are working with the multi-modal sequence

[c] [a] [t] [4] [1] [2] [3]

what we do during training is first append an EOS

[c] [a] [t] [4] [1] [2] [3] [eos]

then, for the autoregressive objective, we take the range [:-1] to predict [1:]

[a] [t] [4] [1] [2] [3] [eos]

<bottom tokens predict top>

[c] [a] [t] [4] [1] [2] [3]

you can see that, since the last text token predicts the first visual token, the logits mask is off by one

[T] = logits masked to text tokens only, [I] = logits masked to image tokens only

[T] [T] [I] [I] [I] [I] [I]

[a] [t] [4] [1] [2] [3] [eos]

<bottom tokens predict top>

[c] [a] [t] [4] [1] [2] [3]

As for your second question: we are always trying to predict the next token. Even without the masking, the attention net would eventually figure it out; the logits mask just makes things a bit cleaner and ensures that during sampling we don't sample a token of the wrong modality.
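
Concretely, the loss line you quoted just implements this next-token shift; roughly (toy sizes, the real logits come out of the transformer):

import torch
import torch.nn.functional as F
from einops import rearrange

b, seq_len, num_classes = 1, 8, 15                   # e.g. [c] [a] [t] [4] [1] [2] [3] [eos]
labels = torch.randint(0, num_classes, (b, seq_len))
logits = torch.randn(b, seq_len - 1, num_classes)    # predictions made from the [:-1] inputs

# logits at position i are scored against the label at position i + 1
loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels[:, 1:])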

lucidrains commented 3 years ago

The EOS token is not strictly necessary if you assume a fixed length (256 + 1024), but I just wanted to give the last token something sensible to predict

lucidrains commented 3 years ago

@CDitzel it is also a possibility that DALL-E was trained with full attention on the text tokens, in which case, I may simplify this in the future so that text tokens are not included in the logit space at all :)

lucidrains commented 3 years ago

@CDitzel So after I wrote up all that, I realized that perhaps my implementation was needlessly complex

I decided to switch over from using an EOS to having a BOS (assuming 0 as padding and as BOS): https://github.com/lucidrains/DALLE-pytorch/pull/22. This also gets rid of some off-by-one confusion (but also introduces some tensor slicing here and there).
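
Roughly, the BOS scheme boils down to this (a sketch, assuming id 0 is reserved for BOS / padding):

import torch
import torch.nn.functional as F

tokens = torch.randint(1, 15, (1, 7))              # [text ..., image ...] ids, 0 reserved for <bos>
inp    = F.pad(tokens, (1, 0), value=0)[:, :-1]    # [<bos>, t1, ..., t6] goes into the transformer
target = tokens                                    # [t1, ..., t7] is what gets predicted
# so <bos> predicts the first text token, and the last text token predicts the first image token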

let me know if that makes more sense!

CDitzel commented 3 years ago

thank you for your effort and time Phil. I will have a look tomorrow and get back to you.

Again so many thanks!

CDitzel commented 3 years ago

Hi Phil, after carefully going through your explanations and the changes you recently merged, things are starting to become clearer, I guess. I still hope you don't mind me asking a couple of further questions.

  1. Do I understand correctly that prepending the BOS token to the beginning of the text sequence causes the related embedding to actually predict the first real text token? So the last text token of the input sequence ultimately results in a prediction for the first image token (the first codebook index of the VAE)?

  2. Does this mean that the transformer actually never sees the embedding of the very last image token?

  3. The correct last codebook index is predicted instead by the second to the last image token, right?

  4. So there is only one cross-modal prediction happening, right? The remaining predictions either go from text to text or from image to image in a modality-isolated fashion? This seems rather arbitrary. Wouldn't it make more sense to enforce stronger inter-modal connections? I get that in the attention layers both modalities can attend to each other freely, but the final prediction procedure still only has this one interface between text and image...

  5. Is the codebook not only trained during VAE pre-training but also further adapted during the training of DALL-E?

  6. What is the intuition behind using argmax to find the largest entry across the feature dimension of the VAE encoder for every remaining spatially downsampled image pixel? Does this correspond to finding the feature map with the highest response? I am looking for a vivid reason for using argmax to extract the codebook indices.

  7. Stripping away the sophisticated transformer extensions like Reformer and sparse attention, the basic procedure should also work with a standard transformer, I guess, right?

  8. The way I see it, codebook_dim and dim of the transformer have to be of equal size due to the subsequent concatenation operation. Is this always mandatory?

These again are a lot of questions, but it is very rare to find knowledgeable people who are also helpful, so I decided to take my chance. Thank you so much in advance!

lucidrains commented 3 years ago

Hi Phil, after carefully going through your explanations and the changes you recently merged, things are starting to become clearer, I guess. I still hope you don't mind me asking a couple of further questions.

  1. Do I understand correctly that prepending the BOS token to the beginning of the text sequence causes the related embedding to actually predict the first real text token? So the last text token of the input sequence ultimately results in a prediction for the first image token (the first codebook index of the VAE)?

Correct!

  2. Does this mean that the transformer actually never sees the embedding of the very last image token?

Yup, we are assuming a fixed length generation, so once we hit the last image token, we stop trying to predict the next one

  3. The correct last codebook index is predicted instead by the second to the last image token, right?

Yup correct

  4. So there is only one cross-modal prediction happening, right? The remaining predictions either go from text to text or from image to image in a modality-isolated fashion? This seems rather arbitrary. Wouldn't it make more sense to enforce stronger inter-modal connections? I get that in the attention layers both modalities can attend to each other freely, but the final prediction procedure still only has this one interface between text and image...

So in this specific setup, because text tokens precede image tokens, the image tokens can attend to all the text tokens (but not the other way around). However, one can imagine a future system where you don't have such restrictions and just mix all tokens from all modalities together in any order

  5. Is the codebook not only trained during VAE pre-training but also further adapted during the training of DALL-E?

I believe the codebook is pretrained only in the VAE, and the DALL-E trains its own embeddings for the visual tokens

  6. What is the intuition behind using argmax to find the largest entry across the feature dimension of the VAE encoder for every remaining spatially downsampled image pixel? Does this correspond to finding the feature map with the highest response? I am looking for a vivid reason for using argmax to extract the codebook indices.

The VAE pre-training should encourage the encoder to discretize the image to unique codebook entries across the latent feature map.
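
Roughly, the quantization works something like this (an illustrative sketch, not the exact repo code):

import torch
import torch.nn.functional as F

num_tokens, codebook_dim = 10, 256
codebook = torch.randn(num_tokens, codebook_dim)
enc_logits = torch.randn(1, num_tokens, 8, 8)                   # encoder scores over the codebook

# during VAE training: a (soft) one-hot choice over codebook entries via gumbel-softmax
soft_one_hot = F.gumbel_softmax(enc_logits, tau=0.9, dim=1)
quantized = torch.einsum('b n h w, n d -> b d h w', soft_one_hot, codebook)

# at inference: the hard argmax just picks the entry the encoder has learned to commit to
hard_codes = enc_logits.argmax(dim=1)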

  7. Stripping away the sophisticated transformer extensions like Reformer and sparse attention, the basic procedure should also work with a standard transformer, I guess, right?

Yup, it's exactly like GPT, but for text and image tokens. Nothing complicated

  8. The way I see it, codebook_dim and dim of the transformer have to be of equal size due to the subsequent concatenation operation. Is this always mandatory?

So this is a mistake on my part, I'll remove this restriction. I had thought perhaps there was a way to share codebook embeddings between DALL-E and VAE, but I don't think that would work
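
one way to do that, for example, would be to give DALL-E its own image token embedding at the transformer width, independent of codebook_dim (just a sketch):

import torch.nn as nn

num_image_tokens, codebook_dim, dim = 10, 256, 512
image_token_emb = nn.Embedding(num_image_tokens, dim)   # lives in DALL-E, width = transformer dim
# the VAE codebook (num_image_tokens x codebook_dim) is then only needed for decoding images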

These again are a lot of questions, but it is very rare to find knowledgeable people who are also helpful, so I decided to take my chance. Thank you so much in advance!

No problem! I'm learning as I go as well, so these questions are helpful for me to think out loud

CDitzel commented 3 years ago

thank you once again for your answers and the possibility to discuss matters here.

I believe the codebook is pretrained only in the VAE, and the DALL-E trains its own embeddings for the visual tokens

mh but right now, imho the codebook is also adjusted during DALL-E training...

So in this specific setup, because text tokens precede image tokens, the image tokens can attend to all the text tokens (but not the other way around). However, one can imagine a future system where you don't have such restrictions and just mix all tokens from all modalities together in any order

I cannot follow. The attention of the transformer receives the input, which is the text and image tokens concatenated along the token dimension. But since the mask during training has True set everywhere, isn't this full-fledged attention from every token to every other? Maybe I didn't understand properly.

Another thing which I don't get is why there are two BOS tokens prepended. Once in the generate_images function

https://github.com/lucidrains/DALLE-pytorch/blob/6a50564387bb518d1e4c2ad9988e0e1cd09225ef/dalle_pytorch/dalle_pytorch.py#L340

and then again here

https://github.com/lucidrains/DALLE-pytorch/blob/6a50564387bb518d1e4c2ad9988e0e1cd09225ef/dalle_pytorch/dalle_pytorch.py#L379

because of this

https://github.com/lucidrains/DALLE-pytorch/blob/6a50564387bb518d1e4c2ad9988e0e1cd09225ef/dalle_pytorch/dalle_pytorch.py#L400

doesn't this cause the first image token to be predicted not by the last but by the second-to-last text token?

lucidrains commented 3 years ago

thank you once again for your answers and the possibility to discuss matters here.

I believe the codebook is pretrained only in the VAE, and the DALL-E trains its own embeddings for the visual tokens

mh but right now, imho the codebook is also adjusted during DALL-E training...

yea, someone else actually brought that up. I don't believe so, because if you read the iGPT paper, they clustered the pixel space into 512 values and then simply retrained on those 512 values as unique embeddings, and it still worked. however, I have a branch in this repository named 'end-to-end' that contains what you are describing and you are free to try it out

So in this specific setup, because text tokens precede image tokens, the image tokens can attend to all the text tokens (but not the other way around). However, one can imagine a future system where you don't have such restrictions and just mix all tokens from all modalities together in any order

I cannot follow. The attention of the transformer receives the input, which is the text and image tokens concatenated along the token dimension. But since the mask during training has True set everywhere, isn't this full-fledged attention from every token to every other? Maybe I didn't understand properly.

so the attention is only from future to past because the causal flag is turned on https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/transformer.py#L86 https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py#L294
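
the mask you pass in (torch.ones_like(text).bool()) only marks which text tokens are real vs padding; the causal restriction is a separate triangular mask built inside attention. A rough sketch of that standard construction:

import torch

seq_len = 7
causal_mask = torch.ones(seq_len, seq_len).triu(1).bool()   # True above the diagonal
# attention scores are set to -inf wherever causal_mask is True, so token i can only attend
# to tokens 0..i -- image tokens see all text tokens, but text tokens never see image tokens
print(causal_mask)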

Another thing which I don't get is why there are two BOS tokens prepended. Once in the generate_images function

https://github.com/lucidrains/DALLE-pytorch/blob/6a50564387bb518d1e4c2ad9988e0e1cd09225ef/dalle_pytorch/dalle_pytorch.py#L340

and then again here

https://github.com/lucidrains/DALLE-pytorch/blob/6a50564387bb518d1e4c2ad9988e0e1cd09225ef/dalle_pytorch/dalle_pytorch.py#L379

because of this

https://github.com/lucidrains/DALLE-pytorch/blob/6a50564387bb518d1e4c2ad9988e0e1cd09225ef/dalle_pytorch/dalle_pytorch.py#L400

doesn't this cause the first image token to be predicted not by the last but by the second-to-last text token?

that's a bug on my part, fixed in the latest commit! :pray:

CDitzel commented 3 years ago

I am wondering: if instead of text one has another image modality, say the left image of a stereo camera pair where the right image has been used to train the VAE, how would one go about using this in DALL-E? According to the discussion section of this repo, the camera image has to be tokenized.

I am contemplating whether it makes more sense to use another VAE for the second stream of images and rely on its resulting codebook indices, or whether it is more reasonable to use e.g. a ViT prior to token concatenation and feeding into the main transformer of DALL-E. Maybe even a simple trainable ViT embedding layer within the forward pass of DALL-E before the concatenation suffices?

I am just spitballing here and would be grateful for yours or anyone else's take on this

snoop2head commented 2 years ago

It seems like dalle-pytorch has used BPE based on individual letters, rather than the GPT-3 or BART encoder. Have I understood that correctly?

This issue helped out a lot! Thanks 🤗