lucidrains / DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch
MIT License

codebook keeps getting trained during DALLE training #35

Open CDitzel opened 3 years ago

CDitzel commented 3 years ago

https://github.com/lucidrains/DALLE-pytorch/blob/40f41199b3f4a355108c64db3ef018d9271bb131/dalle_pytorch/dalle_pytorch.py#L290

right now, neither an appropriate no_grad call nor manually calling codebook.requires_grad_(False) prevents the pretrained VAE codebook from being further adjusted during the subsequent DALLE training procedure.

I am not sure whether this is intended.

Training of the VAE encoder part is correctly disabled by the associated decorator https://github.com/lucidrains/DALLE-pytorch/blob/40f41199b3f4a355108c64db3ef018d9271bb131/dalle_pytorch/dalle_pytorch.py#L122

but this does not apply to the codebook. Maybe I am missing something here? I just wanted to draw attention to this point.
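
For illustration, this is the kind of freezing I would expect to be sufficient (a minimal sketch; the function names here are my own, and vae.get_codebook_indices is assumed to be the DiscreteVAE method from this repo):

```python
import torch
from torch import nn

def freeze_vae(vae: nn.Module) -> nn.Module:
    # exclude every pretrained VAE parameter, codebook included,
    # from gradient updates during DALLE training
    for param in vae.parameters():
        param.requires_grad_(False)
    vae.eval()  # also fix dropout / batchnorm behaviour
    return vae

def images_to_token_ids(vae: nn.Module, images: torch.Tensor) -> torch.Tensor:
    # additionally avoid building a graph through the encoder at all
    with torch.no_grad():
        return vae.get_codebook_indices(images)  # assumed method name, as in DiscreteVAE
```

Neither of these stops the codebook weights from changing here, which makes me suspect the gradient path runs through a tied copy of the codebook inside DALLE itself.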

lucidrains commented 3 years ago

@CDitzel Good timing! I'm about to get back to work on DALL-E today and tomorrow, going to make the training easy for everyone :)

https://github.com/lucidrains/DALLE-pytorch/releases/tag/0.0.54 I've released a new version where tied embeddings are turned off by default, and if they are turned on, the codebook is detached properly so it doesn't get trained. Thanks for catching that!
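
In case it helps anyone reading along, the detaching amounts to something like this (a sketch, not the exact code in the release):

```python
import torch
from torch import nn
import torch.nn.functional as F

class TiedImageEmbedding(nn.Module):
    # reuse a pretrained VAE codebook as the image-token embedding,
    # but detach its weight so DALLE training cannot update the codebook
    def __init__(self, codebook: nn.Embedding):
        super().__init__()
        self.codebook = codebook

    def forward(self, image_token_ids: torch.Tensor) -> torch.Tensor:
        weight = self.codebook.weight.detach()  # cut the gradient path back to the VAE
        return F.embedding(image_token_ids, weight)
```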

CDitzel commented 3 years ago

thank you for attending to this so quickly!

Is a separate embedding for the text and the image tokens even necessary?

I have seen similar implementations that simply concatenate the tokens and pass them to a transformer with only a single nn.Embedding

lucidrains commented 3 years ago

yup, you can do one single embedding! you would just need to offset one set of tokens by the number of tokens in the other

i don't think it matters too much :)

lucidrains commented 3 years ago

for now, let's keep it separate, so it could be optionally tied (or not)

CDitzel commented 3 years ago

> yup, you can do one single embedding! you would just need to offset one set of tokens by the number of tokens in the other
>
> i don't think it matters too much :)

do you really need to? I believe one could just index into one and the same embedding with the indices of both modalities, even though they span identical integer ranges

lucidrains commented 3 years ago

@CDitzel ohh, well, i meant you would do something like nn.Embedding(num_text_tokens + num_image_tokens, dim)

then, when it comes time to retrieve the embeddings, you would offset with image_token_ids += num_text_tokens
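
Concretely, something like this (a sketch with made-up vocabulary sizes):

```python
import torch
from torch import nn

num_text_tokens, num_image_tokens, dim = 10000, 8192, 512  # illustrative sizes

# one table holding both vocabularies back to back
token_emb = nn.Embedding(num_text_tokens + num_image_tokens, dim)

text_token_ids = torch.randint(0, num_text_tokens, (1, 256))
image_token_ids = torch.randint(0, num_image_tokens, (1, 1024))

# offset the image ids so the two modalities occupy disjoint rows of the table
tokens = torch.cat([text_token_ids, image_token_ids + num_text_tokens], dim=1)
embedded = token_emb(tokens)  # shape (1, 256 + 1024, dim)
```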

CDitzel commented 3 years ago

yeah, I understood what you meant. But I think one could just use

nn.Embedding(larger_num_of_both_token_len, dim)

and then index into it with both token types directly, even though this means that every so often a text token and an image token would retrieve the same embedding vector
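
i.e. roughly this (same made-up sizes as above, just without the offset):

```python
import torch
from torch import nn

num_text_tokens, num_image_tokens, dim = 10000, 8192, 512  # illustrative sizes

# one table sized for the larger of the two vocabularies
token_emb = nn.Embedding(max(num_text_tokens, num_image_tokens), dim)

text_token_ids = torch.randint(0, num_text_tokens, (1, 256))
image_token_ids = torch.randint(0, num_image_tokens, (1, 1024))

# no offset: text id 5 and image id 5 share the same embedding row
embedded = token_emb(torch.cat([text_token_ids, image_token_ids], dim=1))
```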