lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
MIT License

Question about the concatenated tokens (where is the `noised image token`?) #71

Closed CiaoHe closed 2 years ago

CiaoHe commented 2 years ago

Hi Phil, when reading the DiffusionPriorNetwork forward part, I noticed that the concatenated tokens fed into the CausalTransformer are composed as below: https://github.com/lucidrains/DALLE2-pytorch/blob/fd53fa17db37dcec2e89c334da3fffcd89285ff7/dalle2_pytorch/dalle2_pytorch.py#L775-L780

But, referring to the original paper, Section 2.2 describes the sequence as "...consisting of encoded text, the CLIP text embedding, an embedding for the diffusion timestep, the noised CLIP image embedding, and a final embedding whose output from the Transformer is used to predict the unnoised CLIP image embedding." I just wonder which part corresponds to the noised CLIP image embedding (maybe learned_queries?). It confuses me.
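For reference, a minimal sketch of the token sequence Section 2.2 describes (not the repository's code; tensor names and shapes here are purely illustrative):

```python
import torch

batch, num_text_tokens, dim = 2, 77, 512  # illustrative sizes only

text_encodings     = torch.randn(batch, num_text_tokens, dim)  # encoded text
text_embed         = torch.randn(batch, 1, dim)                # CLIP text embedding
time_embed         = torch.randn(batch, 1, dim)                # diffusion timestep embedding
noised_image_embed = torch.randn(batch, 1, dim)                # noised CLIP image embedding (x_t)
learned_query      = torch.randn(batch, 1, dim)                # final token; its output predicts the unnoised embedding

tokens = torch.cat((
    text_encodings,
    text_embed,
    time_embed,
    noised_image_embed,  # <- the token this question is about
    learned_query,
), dim = -2)             # (batch, num_text_tokens + 4, dim), fed to the causal transformer
```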

Enjoy!

xiankgx commented 2 years ago

We are trying to predict the CLIP image embeddings. Then how can we add the noised version as input?

CiaoHe commented 2 years ago

We are trying to predict the CLIP image embeddings. Then how can we add the noised version as input?

  1. When sampling, each time you pass x_t (image_embed_t) into the DiffusionPriorNetwork, that x_t is the noised_image_embed. (Initially, x_T is just set to random noise.)
  2. When training, it is more straightforward: we feed the (noised) image_embed into the PriorNet, since we have access to the image during training (see the sketch below).
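A rough sketch of point 2, as a standard DDPM-style training step (the function and argument names here are assumptions for illustration, not the repository's exact API):

```python
import torch
import torch.nn.functional as F

def prior_train_step(prior_network, image_embed, text_cond, alphas_cumprod):
    # pick a random diffusion timestep for each example in the batch
    batch, num_timesteps = image_embed.shape[0], alphas_cumprod.shape[0]
    times = torch.randint(0, num_timesteps, (batch,), device = image_embed.device)

    # q_sample: corrupt the clean CLIP image embedding to obtain x_t,
    # i.e. the "noised image embedding"
    noise = torch.randn_like(image_embed)
    a = alphas_cumprod[times].unsqueeze(-1)
    noised_image_embed = a.sqrt() * image_embed + (1. - a).sqrt() * noise

    # the prior network sees the noised embedding (plus the text conditioning)
    # and is trained to predict the clean embedding
    pred = prior_network(noised_image_embed, diffusion_timesteps = times, **text_cond)
    return F.mse_loss(pred, image_embed)
```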
xiankgx commented 2 years ago

I don't understand you. The purpose of the Prior is to predict a range of CLIP image embeddings given these inputs:

  • CLIP text embeddings, and
  • optionally, text

In that case, how do we pass CLIP image embeddings to the Prior network to noise them? (your point #1)

Also, the purpose of the entire pipeline Prior -> Decoder is so we can input to the Prior:

  • text
  • the CLIP text embedding obtained from the text. The Prior should then predict the CLIP image embedding, to be decoded by the Decoder to generate some image.

We have access to the image, and hence the CLIP image embedding, during training. But how do we obtain it at test time, when no image is available, if we use the CLIP image embedding (or some noised version of it) as an input to the Prior?
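In other words, the end-to-end flow I mean is roughly the following (all callables and signatures here are illustrative placeholders, not the repository's API):

```python
def generate_image(clip, prior, decoder, text):
    # 1. text -> CLIP text embedding (and token-level encodings)
    text_embed, text_encodings = clip.embed_text(text)

    # 2. Prior: predict a CLIP image embedding from text conditioning alone;
    #    no image (and hence no clean image embedding) is available here
    image_embed = prior.sample(text_embed = text_embed, text_encodings = text_encodings)

    # 3. Decoder: generate the actual image from the predicted image embedding
    return decoder.sample(image_embed = image_embed)
```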

CiaoHe commented 2 years ago

I don't understand you. The purpose of the Prior is to predict a range of CLIP image embeddings given these inputs:

  • CLIP text embeddings, and
  • optionally, text

In that case, how do we pass CLIP image embeddings to the Prior network to noise them? (your point #1)

Also, the purpose of the entire pipeline Prior -> Decoder is so we can input to the Prior:

  • text
  • the CLIP text embedding obtained from the text. The Prior should then predict the CLIP image embedding, to be decoded by the Decoder to generate some image.

We have access to the image, and hence the CLIP image embedding, during training. But how do we obtain it at test time, when no image is available, if we use the CLIP image embedding (or some noised version of it) as an input to the Prior?

Let me clarify step by step. When sampling, you use p_sample_loop(), right? p_sample_loop() just calls p_sample() repeatedly to run the reverse process (generating a clean embedding from noise), so the initial image_embed (as written at line 881) is randomly initialized: https://github.com/lucidrains/DALLE2-pytorch/blob/fd53fa17db37dcec2e89c334da3fffcd89285ff7/dalle2_pytorch/dalle2_pytorch.py#L877-L885

Then, when sampling, p_sample() calls p_mean_variance() to get the \mu and \sigma used for sampling: https://github.com/lucidrains/DALLE2-pytorch/blob/fd53fa17db37dcec2e89c334da3fffcd89285ff7/dalle2_pytorch/dalle2_pytorch.py#L870. Here x is just image_embed; all the text-related information is included in text_cond (a dict).

Then, inside p_mean_variance(), the PriorNet is forwarded: https://github.com/lucidrains/DALLE2-pytorch/blob/fd53fa17db37dcec2e89c334da3fffcd89285ff7/dalle2_pytorch/dalle2_pytorch.py#L849. What is x here? I think it should still be image_embed.

Next, jump into the PriorNet's forward, and we can see it takes image_embed in: https://github.com/lucidrains/DALLE2-pytorch/blob/fd53fa17db37dcec2e89c334da3fffcd89285ff7/dalle2_pytorch/dalle2_pytorch.py#L727-L729
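Put differently, the sampling loop being described amounts to something like the following hedged sketch (a simple linear beta schedule and the prior_network signature here are assumptions, not the repository's exact implementation):

```python
import torch

@torch.no_grad()
def sample_image_embed(prior_network, text_cond, embed_dim, num_timesteps = 1000, batch = 1, device = 'cpu'):
    # linear beta schedule (a simplification; the real schedule is configurable)
    betas = torch.linspace(1e-4, 0.02, num_timesteps, device = device)
    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, dim = 0)
    alphas_cumprod_prev = torch.cat([torch.ones(1, device = device), alphas_cumprod[:-1]])

    # start from pure noise: no image is available at inference time
    image_embed = torch.randn(batch, embed_dim, device = device)

    for t in reversed(range(num_timesteps)):
        times = torch.full((batch,), t, device = device, dtype = torch.long)

        # the prior network looks at the current noised embedding x_t plus the
        # text conditioning, and predicts the clean embedding x_0
        pred_x_start = prior_network(image_embed, diffusion_timesteps = times, **text_cond)

        # DDPM posterior mean of q(x_{t-1} | x_t, x_0)
        coef1 = betas[t] * alphas_cumprod_prev[t].sqrt() / (1. - alphas_cumprod[t])
        coef2 = (1. - alphas_cumprod_prev[t]) * alphas[t].sqrt() / (1. - alphas_cumprod[t])
        mean = coef1 * pred_x_start + coef2 * image_embed

        posterior_var = betas[t] * (1. - alphas_cumprod_prev[t]) / (1. - alphas_cumprod[t])
        noise = torch.randn_like(image_embed) if t > 0 else torch.zeros_like(image_embed)
        image_embed = mean + posterior_var.sqrt() * noise

    return image_embed  # the generated CLIP image embedding, handed off to the decoder
```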


So, my point is: at inference time, image_embed is just randomly initialized, since we don't have any image. During the PriorNet's generation process, image_embed is refined (or, say, generated) using the text (or text_embed) information. Once you have the generated image_embed, the rest of the work is passed on to the Decoder.

But anyway, in the current version of the PriorNet forward pass, I cannot see image_embed joining the combined tokens that are fed into the CausalTransformer. This is what I am concerned about.

lucidrains commented 2 years ago

@CiaoHe oh gosh, i was rereading the code over and over again because i simply couldn't believe i made a big mistake like this; missing the most critical part of the diffusion prior, which is attending to the previous noised image embeddings. must be getting old :laughing: should be fixed https://github.com/lucidrains/DALLE2-pytorch/commit/85ed77d512e692bd2ffd0c72519b6385e77df208

thank you kindly for catching this (and you've caught so many mistakes i've made across repositories by now :pray: , for Alphafold2 and others)

lucidrains commented 2 years ago

@CiaoHe even mentioned it twice in the comments! https://github.com/lucidrains/DALLE2-pytorch/blob/0.1.6/dalle2_pytorch/dalle2_pytorch.py#L768 just failed to concat it for attention, my god :facepalm:

xiankgx commented 2 years ago

Oops, yeah, seems like the x_t or noised_image_embedding wasn't there.

lucidrains commented 2 years ago

@nousr you will probably want to redo the diffusion prior experiments given this issue! sorry!

lucidrains commented 2 years ago

@CiaoHe if you ever visit san francisco let me know, drinks (or coffee) are on me :smile:

CiaoHe commented 2 years ago

@lucidrains Haha, thanks for your attention. I have learned a lot from your code and really wanted to make a little contribution. And thanks for the invitation; if I get the chance, I will thank you in person.