Closed — mhh0318 closed this issue 2 years ago
@mhh0318 hi Michael and thank you for catching this! i believe this could be an important detail that was missing, so it is now added in 0.6.0 https://github.com/lucidrains/DALLE2-pytorch/commit/a0bed30a844f3007665a67349de4d44a97b4ebbe
there are two types of conditioning going on. one is where time is projected and then summed to the hiddens within the resnet blocks (film-like conditioning). the other is cross attention, where both time and image embeddings are projected into tokens to be attended to by the resnet block hiddens. however, the cross attention may not be enough, since a lot of layers do not see it due to computational constraints. 0.6.0 should allow the former type of conditioning to use both time and image, as you noted in the paper
it could also be the case that the projected image hiddens should be summed with the time embedding within the resnet blocks, but we can come back to that if things are still not working
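As a rough illustration of the film-like path described above (a sketch only — the module and argument names here are made up, not the repo's actual identifiers), the projected image embedding can be summed with the time embedding before producing the per-channel scale and shift for a resnet block:

```python
import torch
import torch.nn as nn

class FiLMCondition(nn.Module):
    """Hypothetical FiLM-like conditioning block (names are assumptions)."""

    def __init__(self, cond_dim, channels):
        super().__init__()
        # project the combined conditioning vector to per-channel scale/shift
        self.to_scale_shift = nn.Linear(cond_dim, channels * 2)

    def forward(self, hiddens, time_emb, image_emb):
        # sum the two conditioning signals, as discussed above
        cond = time_emb + image_emb                         # (batch, cond_dim)
        scale, shift = self.to_scale_shift(cond).chunk(2, dim=-1)
        # broadcast over the spatial dimensions of the hiddens
        scale = scale[:, :, None, None]
        shift = shift[:, :, None, None]
        return hiddens * (scale + 1) + shift

film = FiLMCondition(cond_dim=32, channels=16)
hiddens = torch.randn(2, 16, 8, 8)
time_emb = torch.randn(2, 32)
image_emb = torch.randn(2, 32)
out = film(hiddens, time_emb, image_emb)
print(out.shape)  # torch.Size([2, 16, 8, 8])
```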
Hi @lucidrains! Thank you for open-sourcing this implementation.
I'm just wondering whether you changed the timestep-embedding injection from concatenation to addition? I'm reading several papers now, and they all point back to the original DDPM model for how the timestep is injected.
My understanding is that positional encodings are created as in Transformers, except in this case the "word position" is the diffusion timestep and the "word embedding index" is the image channel index. If this were the case, the same encoding value would be added across the image Height and Width for each channel and each timestep.
I.e., a tensor M of shape (H, W, C, T), where M[:, :, c, t] is a constant (H, W) slice for any channel c and timestep t.
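To make that concrete (a sketch with assumed names, not the actual DDPM code): a transformer-style sinusoidal encoding of the timestep gives one value per channel, which is then broadcast identically over H and W:

```python
import math
import torch

def sinusoidal_timestep_embedding(t, channels):
    # transformer-style sinusoidal encoding, with the diffusion timestep t
    # playing the role of the "word position" and the channel index the
    # role of the "word embedding index"
    half = channels // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half) / half)
    args = t * freqs
    return torch.cat((args.sin(), args.cos()))

H, W, C = 4, 4, 8
emb = sinusoidal_timestep_embedding(torch.tensor(10.0), C)  # shape (C,)
# the same per-channel value is broadcast across every spatial location
M_t = emb.expand(H, W, C)                                   # shape (H, W, C)
# every (H, W) slice at a fixed channel is constant, as described above
print(bool(torch.all(M_t[:, :, 0] == M_t[0, 0, 0])))  # True
```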
I'm wondering whether this matches your understanding, and whether this is how you implemented the timestep injection, in case you did update it from the concatenation pointed out here.
Thanks so much!
Thanks for sharing.
For the image decoder part, the description in the DALL-E 2 paper is:
Does that mean the embedding should be
instead of
https://github.com/lucidrains/DALLE2-pytorch/blob/387c5bf77494dec8d4d566343e893bc84ef30d03/dalle2_pytorch/dalle2_pytorch.py#L1617
I'm not sure whether I'm correct, as I don't have a trained decoder model yet 🤦♂️
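To make the question concrete, here is a rough sketch (all names hypothetical, not taken from the repo) of the two alternatives: summing a projected image embedding with the time embedding, versus concatenating the two along the feature dimension:

```python
import torch
import torch.nn as nn

batch, time_dim, image_dim = 2, 32, 64
time_emb = torch.randn(batch, time_dim)
image_emb = torch.randn(batch, image_dim)

# alternative 1: project the image embedding to the time embedding's
# dimension and sum the two (one conditioning vector of size time_dim)
to_time_dim = nn.Linear(image_dim, time_dim)  # hypothetical projection
summed = time_emb + to_time_dim(image_emb)

# alternative 2: concatenate along the feature dimension
# (one conditioning vector of size time_dim + image_dim)
concatenated = torch.cat((time_emb, image_emb), dim=-1)

print(summed.shape)        # torch.Size([2, 32])
print(concatenated.shape)  # torch.Size([2, 96])
```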
Best,