CDitzel opened this issue 3 years ago
@CDitzel Good timing! I'm about to get back to work on DALL-E today and tomorrow, going to make the training easy for everyone :)
https://github.com/lucidrains/DALLE-pytorch/releases/tag/0.0.54 I've released a new version where I turn off tying embeddings, and if they are turned on, I detach it properly so it doesn't get trained. Thanks for catching that!
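For reference, here is a minimal sketch of the detach-on-tie idea (the class and attribute names are illustrative, not the actual dalle_pytorch code): the image-token embedding reuses the pretrained codebook weights, but detaches them before the lookup so the DALLE loss cannot update the codebook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TiedImageEmbedding(nn.Module):
    """Hypothetical sketch: reuse a pretrained VAE codebook as the image-token
    embedding, detaching it so DALLE training leaves the codebook untouched."""
    def __init__(self, vae_codebook: nn.Embedding):
        super().__init__()
        self.codebook = vae_codebook  # pretrained weights, should stay frozen

    def forward(self, image_token_ids):
        # detach before the lookup so no gradient flows back into the codebook
        return F.embedding(image_token_ids, self.codebook.weight.detach())
```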
thank you for attending to this so quickly!
Is a separate embedding for the text and the image tokens even necessary?
I have seen similar implementations that just concat the tokens and pass them to a transformer with a single nn.Embedding.
yup you can do one single embedding! you would just need to offset one set of tokens by the number in the other
i don't think it matters too much :)
for now, let's keep it separate, so it could be optionally tied (or not)
do you really? I believe one could just index into one and the same embedding with the indices of both modalities, even though they span identical integer ranges
@CDitzel ohh, well, i meant you would do something like nn.Embedding(num_text_tokens + num_image_tokens, dim)
then, when it comes time to retrieve the embedding, do image_token_ids += num_text_tokens
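A minimal sketch of that offset scheme (the vocabulary sizes and dimensions below are made up for illustration):

```python
import torch
import torch.nn as nn

num_text_tokens, num_image_tokens, dim = 10000, 8192, 512

# one shared table covering both vocabularies
token_emb = nn.Embedding(num_text_tokens + num_image_tokens, dim)

text_ids = torch.randint(0, num_text_tokens, (1, 256))
image_ids = torch.randint(0, num_image_tokens, (1, 1024))

# shift the image ids past the text range so the two modalities never share a row
image_ids = image_ids + num_text_tokens

tokens = token_emb(torch.cat((text_ids, image_ids), dim=1))  # (1, 1280, 512)
```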
yeah, I understood what you meant. But I think one could just use
nn.Embedding(larger_num_of_both_token_len, dim)
and index into it with both token types unchanged, even though this means that every so often a text token and an image token would retrieve the same embedding vector. See the sketch below.
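Roughly what that would look like (again with made-up sizes); without the offset, a text id and an image id with the same value land on the same row:

```python
import torch
import torch.nn as nn

num_text_tokens, num_image_tokens, dim = 10000, 8192, 512

# one table sized to the larger vocabulary, indexed by raw ids from both modalities
shared_emb = nn.Embedding(max(num_text_tokens, num_image_tokens), dim)

text_ids = torch.randint(0, num_text_tokens, (1, 256))
image_ids = torch.randint(0, num_image_tokens, (1, 1024))  # no offset applied

tokens = torch.cat((shared_emb(text_ids), shared_emb(image_ids)), dim=1)
```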
https://github.com/lucidrains/DALLE-pytorch/blob/40f41199b3f4a355108c64db3ef018d9271bb131/dalle_pytorch/dalle_pytorch.py#L290
Right now, neither an appropriate no_grad call nor manually disabling the codebook via requires_grad_(False) prevents the pretrained VAE codebook from getting further adjusted during the subsequent DALLE training procedure.
I am in doubt whether this is meant to be the case.
Training of the VAE encoder part is rightfully disabled by the associated decorator https://github.com/lucidrains/DALLE-pytorch/blob/40f41199b3f4a355108c64db3ef018d9271bb131/dalle_pytorch/dalle_pytorch.py#L122
but this does not pertain to the codebook. Maybe I am missing something here? Just wanted to draw attention to this point.
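A quick way to check the behaviour described above is to snapshot the codebook weights before and after a single DALLE training step. The tiny model sizes below are arbitrary, and the constructor arguments follow the README-style API of dalle_pytorch, so adapt them to your installed version:

```python
import torch
from dalle_pytorch import DiscreteVAE, DALLE

# deliberately tiny configuration, only meant to exercise one training step
vae = DiscreteVAE(image_size=64, num_layers=2, num_tokens=512, codebook_dim=256, hidden_dim=32)
dalle = DALLE(dim=256, vae=vae, num_text_tokens=1000, text_seq_len=32, depth=2, heads=4)

text = torch.randint(0, 1000, (1, 32))
images = torch.randn(1, 3, 64, 64)

# snapshot the pretrained codebook before the step
before = vae.codebook.weight.detach().clone()

loss = dalle(text, images, return_loss=True)
loss.backward()
torch.optim.Adam(dalle.parameters(), lr=3e-4).step()

# if the codebook is properly frozen/detached, this should print False
print("codebook changed after one step:", not torch.equal(before, vae.codebook.weight.detach()))
```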