lucidrains / recurrent-interface-network-pytorch

Implementation of Recurrent Interface Network (RIN), for highly efficient generation of images and video without cascading networks, in Pytorch

Handling of the time conditioning #9

Closed: nicolas-dufour closed this issue 1 year ago

nicolas-dufour commented 1 year ago

Hi, thanks for the implementation. I was trying to implement the RIN paper and was looking at this implementation for reference, and I noticed that the time conditioning is handled differently. In the paper, they say they just add it as one more latent token, whereas here it looks closer to how other diffusion models handle conditioning.

Is there some reason for this choice of implementation? Does it improve performance? Thanks!

lucidrains commented 1 year ago

@nicolas-dufour only that I was unaware of this! Would be happy to make the correction if you can point me to the relevant passage.

Do they discretize the time step and look up an embedding, or construct it from an MLP?
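For concreteness, the two options in question might look roughly like this (a sketch with made-up names, not this repo's actual code):

```python
import math
import torch
import torch.nn as nn

dim, num_timesteps = 256, 1000

# Option A: discretize t and look up a learned embedding
time_embed_table = nn.Embedding(num_timesteps, dim)

# Option B: keep t continuous, build sinusoidal features, then pass them through a small MLP
def sinusoidal_features(t, dim):
    half = dim // 2
    freqs = torch.exp(-math.log(10000) * torch.arange(half) / half)
    angles = t[:, None] * freqs[None, :]
    return torch.cat((angles.sin(), angles.cos()), dim=-1)

time_mlp = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

t_discrete = torch.randint(0, num_timesteps, (8,))        # integer timesteps
t_continuous = torch.rand(8)                              # continuous noise level in [0, 1]

emb_a = time_embed_table(t_discrete)                      # (8, dim)
emb_b = time_mlp(sinusoidal_features(t_continuous, dim))  # (8, dim)
```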

nicolas-dufour commented 1 year ago

@lucidrains The paper says the following: "Conditioning variables, such as class labels and time step t of diffusion models, are mapped to embeddings; in our experiments, we simply concatenate them to the set of latents, since they only account for two tokens."

So I would say they map them with an MLP and then concatenate them to the rest of the tokens.
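Something along these lines, as a minimal sketch of that passage (names are made up, this is not the repo's or pix2seq's actual code):

```python
import torch
import torch.nn as nn

class LatentTokenConditioning(nn.Module):
    """Map conditioning variables (time, optional class label) to embeddings
    and concatenate them to the latent tokens along the sequence axis."""

    def __init__(self, dim, num_classes=None):
        super().__init__()
        # time features (e.g. sinusoidal) -> one conditioning token
        self.time_mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )
        # optional class label -> one more conditioning token
        self.class_emb = nn.Embedding(num_classes, dim) if num_classes else None

    def forward(self, latents, time_feats, labels=None):
        # latents: (batch, num_latents, dim), time_feats: (batch, dim)
        cond_tokens = [self.time_mlp(time_feats).unsqueeze(1)]       # (batch, 1, dim)
        if self.class_emb is not None and labels is not None:
            cond_tokens.append(self.class_emb(labels).unsqueeze(1))  # (batch, 1, dim)
        # the latent set only grows by one or two tokens
        return torch.cat([latents, *cond_tokens], dim=1)
```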

lucidrains commented 1 year ago

@nicolas-dufour ok, should be in effect in 0.5.0! Let me know if that looks right to you or if you find any other issues :pray:

lucidrains commented 1 year ago

@nicolas-dufour are you planning on trying this out during your internship at Meta? I would welcome you sharing your experiments comparing adaptive layernorm vs. latent-token conditioning, if you are allowed to do so.

nicolas-dufour commented 1 year ago

I'm not at Meta anymore, but I'm working on this for my PhD. I will try to run some experiments and will post here if I find anything!

lucidrains commented 1 year ago

@nicolas-dufour yes please, I'm very curious which approach is better!

If it turns out latent-token conditioning is better, it would be another 'attention is all we need' moment :smile:

lucidrains commented 1 year ago

@nicolas-dufour also, a long time ago I implemented DDPM incorrectly: I forgot to condition on time, and yet it still converged and sampled decent pictures. I never dug into it, but thought I'd share.

nicolas-dufour commented 1 year ago

> @nicolas-dufour also, a long time ago I implemented DDPM incorrectly: I forgot to condition on time, and yet it still converged and sampled decent pictures. I never dug into it, but thought I'd share.

Hmm, the model might have learned an implicit representation of the noise level. From what I understand, time conditioning was introduced to make training easier, but it's not strictly mandatory. Especially when you look at the math: you want a model that approximates the score, i.e. the gradient of $\log p(x_t)$, so conceptually there is no need for $t$ to be there.
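To spell that out a bit: under the usual DDPM noise-prediction parameterization, the ideal denoiser at noise level $t$ matches (up to scale) the score of the noised marginal,

$$\epsilon_\theta(x_t, t) \approx -\sqrt{1 - \bar\alpha_t}\,\nabla_{x_t} \log p_t(x_t),$$

and since $x_t$ itself carries information about how noisy it is, the network can in principle infer the noise level without being told $t$ explicitly.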

ajabri commented 1 year ago

Concatenating the time and class embeddings to the set of latents (i.e. along the sequence axis) is the default used in the paper. We've open-sourced the code; here is the architecture, in case it's of help: https://github.com/google-research/pix2seq/blob/main/architectures/tape.py