lucidrains / recurrent-interface-network-pytorch

Implementation of Recurrent Interface Network (RIN), for highly efficient generation of images and video without cascading networks, in PyTorch
MIT License

Possible Normalization Bug #6

Closed: MattMcPartlon closed this issue 1 year ago

MattMcPartlon commented 1 year ago

Thanks for another great implementation, Phil!

You're using the Attention Block to do attention between latent features (i.e. "process" step from RIN paper). It looks like you're not Layer-Normalizing the context features in Attention when no context is provided (a logical move :)).

When you initialize the latent attention blocks in RINBlock, you specify norm=True, so you're layer-normalizing the latent features before computing the query vectors. Unfortunately, the context is set to the unnormalized latent features, which are then used to compute the keys and values.

I could be wrong, or this could be intentional. Just wanted to give you a heads-up.

TLDR: When you use the Attention class for regular self-attention (not cross-attention), the features used to compute keys and values are not normalized.
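A minimal sketch of the ordering issue, assuming a simplified single-head attention module. The class below and its norm_self_context flag are hypothetical and are not the repo's actual Attention class. The point is only where the "context defaults to x" fallback sits relative to the layer norm: before it, keys and values see the raw latents; after it, they see the normalized ones.

```python
import torch
from torch import nn

class Attention(nn.Module):
    """Hypothetical single-head attention, only to illustrate the norm ordering."""

    def __init__(self, dim, norm=False, norm_self_context=True):
        super().__init__()
        self.norm = nn.LayerNorm(dim) if norm else nn.Identity()
        self.norm_self_context = norm_self_context
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_kv = nn.Linear(dim, dim * 2, bias=False)
        self.scale = dim ** -0.5

    def forward(self, x, context=None):
        x_normed = self.norm(x)
        if context is None:
            # self-attention: keys/values come from the latents themselves.
            # norm_self_context=False mimics the reported bug (raw latents);
            # True is the fix (reuse the normalized latents).
            context = x_normed if self.norm_self_context else x
        q = self.to_q(x_normed)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
        return attn @ v

torch.manual_seed(0)
latents = torch.randn(2, 8, 16) * 5 + 3  # deliberately far from zero-mean/unit-variance

buggy = Attention(16, norm=True, norm_self_context=False)
fixed = Attention(16, norm=True, norm_self_context=True)
fixed.load_state_dict(buggy.state_dict())  # identical weights, only the ordering differs

out_buggy = buggy(latents)
out_fixed = fixed(latents)
print(torch.allclose(out_buggy, out_fixed))
```

With identical weights the two variants disagree on un-normalized latents, which is why the fallback to the latents as context has to happen after the norm is applied.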

MattMcPartlon commented 1 year ago

Simpler explanation: line 199 should move to line 205.

[Screenshot attached, 2023-02-27]
lucidrains commented 1 year ago

@MattMcPartlon Matt! Long time no talk! Thank you for raising this issue, and hope things are going well at the protein design company 😄 are you using this for your bio work? 😮

MattMcPartlon commented 1 year ago

Thanks Phil. Always a pleasure to read over your code 😄. I just saw your email the other day (my bad!). Will respond soon.

I was hoping to do diffusion with your perceiver architecture, then I heard about this paper.

I'm hoping to use some of the ideas here to build a molecular-graph autoencoder. The encoder will alternate cross-attention over the node and pair features with latent self-attention: (kv=node -> q=latent), (kv=latent -> q=latent), (kv=pair -> q=latent), (kv=latent -> q=latent), [repeat], and the decoder will use the same approach as Perceiver IO.
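A rough sketch of that encoder schedule, assuming pre-norm residual attention blocks built on torch.nn.MultiheadAttention, a learned set of latents, and pair features flattened from an (n, n) grid into n*n tokens. All class names, dimensions, and the flattening choice are hypothetical illustrations, not code from this repo, and the Perceiver IO-style decoder is left out.

```python
import torch
from torch import nn

class CrossAttend(nn.Module):
    """Hypothetical pre-norm residual block: latents (queries) attend to a context."""

    def __init__(self, dim, context_dim, heads=4):
        super().__init__()
        self.norm_q = nn.LayerNorm(dim)
        self.norm_kv = nn.LayerNorm(context_dim)  # keys/values are normalized too
        self.attn = nn.MultiheadAttention(
            dim, heads, kdim=context_dim, vdim=context_dim, batch_first=True
        )

    def forward(self, latents, context):
        q = self.norm_q(latents)
        kv = self.norm_kv(context)
        out, _ = self.attn(q, kv, kv, need_weights=False)
        return latents + out

class GraphLatentEncoder(nn.Module):
    """Hypothetical encoder: node cross-attn, latent self-attn, pair cross-attn, latent self-attn, repeated."""

    def __init__(self, dim, node_dim, pair_dim, num_latents=32, depth=2, heads=4):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim) * 0.02)
        self.layers = nn.ModuleList([
            nn.ModuleList([
                CrossAttend(dim, node_dim, heads),  # kv=node -> q=latent
                CrossAttend(dim, dim, heads),       # kv=latent -> q=latent
                CrossAttend(dim, pair_dim, heads),  # kv=pair -> q=latent
                CrossAttend(dim, dim, heads),       # kv=latent -> q=latent
            ])
            for _ in range(depth)
        ])

    def forward(self, nodes, pairs):
        b, n, _ = nodes.shape
        pairs = pairs.reshape(b, n * n, -1)  # assumption: treat each (i, j) pair as a token
        latents = self.latents.expand(b, -1, -1)
        for node_attn, self_attn_a, pair_attn, self_attn_b in self.layers:
            latents = node_attn(latents, nodes)
            latents = self_attn_a(latents, latents)
            latents = pair_attn(latents, pairs)
            latents = self_attn_b(latents, latents)
        return latents

enc = GraphLatentEncoder(dim=64, node_dim=32, pair_dim=16)
nodes = torch.randn(2, 10, 32)      # (batch, atoms, node features)
pairs = torch.randn(2, 10, 10, 16)  # (batch, atoms, atoms, pair features)
z = enc(nodes, pairs)
print(z.shape)  # torch.Size([2, 32, 64])
```

Because every block normalizes its keys/values separately, the self-attention steps avoid the unnormalized-context problem raised in this issue. Flattening the pair grid costs n*n context tokens per pair step, which is one reason a fixed-size latent bottleneck is attractive here.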

I'm not sure this type of graph can be encoded in a fixed-dimension latent space, but it seems worth a try :).

lucidrains commented 1 year ago

@MattMcPartlon yes indeed! this is actually the first time i've seen something perceiver-like used for generations of this fidelity

hope you are able to get it to work for your molecular autoencoder