lucidrains / flamingo-pytorch

Implementation of 🦩 Flamingo, state-of-the-art few-shot visual question answering attention net out of Deepmind, in Pytorch

Data Leakage on cross attention #14

Open DanielWarfield1 opened 9 months ago

DanielWarfield1 commented 9 months ago

Hello! Thanks for making this.

I was looking through MaskedCrossAttention, and I noticed that you generate the keys and values by passing all of the media through a dense network:

```python
k, v = self.to_kv(media).chunk(2, dim = -1)
```

After this point you calculate the attention matrix, build the masks, etc.
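For context, here is a simplified sketch of the flow as I understand it. The shapes, helper names, and exact masking rule are my own approximation for the sake of discussion, not the repo's code verbatim:

```python
import torch
from torch import nn, einsum
from einops import rearrange, repeat

class MaskedCrossAttentionSketch(nn.Module):
    """Rough approximation of the masked cross-attention flow being discussed."""

    def __init__(self, dim, dim_head=64, heads=8):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner = dim_head * heads
        self.norm = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner, bias=False)
        self.to_kv = nn.Linear(dim, inner * 2, bias=False)  # the dense projection in question
        self.to_out = nn.Linear(inner, dim, bias=False)

    def forward(self, x, media, media_locations):
        # x:               (batch, text_len, dim)            text token embeddings
        # media:           (batch, n_images, n_latents, dim) perceiver-resampled image tokens
        # media_locations: (batch, text_len) bool, True at <image> positions in the text
        b, t, m, _ = media.shape
        h = self.heads

        x = self.norm(x)
        q = self.to_q(x)

        # flatten all images into one sequence of media tokens, then project to k, v
        media = rearrange(media, 'b t n d -> b (t n) d')
        k, v = self.to_kv(media).chunk(2, dim=-1)

        q, k, v = (rearrange(z, 'b n (h d) -> b h n d', h=h) for z in (q, k, v))
        sim = einsum('b h i d, b h j d -> b h i j', q * self.scale, k)

        # mask so each text token only attends to the immediately preceding image
        text_time = media_locations.cumsum(dim=-1)          # which image each text token follows
        media_time = torch.arange(t, device=x.device) + 1   # 1-indexed image positions
        mask = rearrange(text_time, 'b i -> b 1 i 1') == repeat(media_time, 'j -> 1 1 1 (j m)', m=m)
        sim = sim.masked_fill(~mask, torch.finfo(sim.dtype).min)
        # (text before the first image matches no media here; my understanding is the real code
        #  also zeroes out that attention explicitly)

        attn = sim.softmax(dim=-1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        return self.to_out(rearrange(out, 'b h n d -> b n (h d)'))
```

So all of the images end up concatenated into a single key/value sequence, and the masking is applied afterwards on the attention logits.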

My question is: isn't the point of Flamingo that a particular text token should only attend to the media from the immediately preceding image sequence? If all of the media is passed through a dense network to generate the keys and values, doesn't that imply that information from any image could end up at any position within the final keys and values? If so, it seems to me that the masking would be moot, since you would be masking, by location, an abstract embedding of all the media inputs, for which location is no longer meaningful. Am I missing something?
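To make the concern concrete, this is the kind of check I have in mind (it reuses the MaskedCrossAttentionSketch above; all names here are mine, not the repo's): perturb the tokens of a single image and look at which key positions change.

```python
import torch
from einops import rearrange

dim = 64
attn = MaskedCrossAttentionSketch(dim=dim, dim_head=16, heads=4)

media = torch.randn(1, 3, 8, dim)       # 3 images, 8 latent tokens each
perturbed = media.clone()
perturbed[:, 0] += 1.0                  # change only the first image

flat = rearrange(media, 'b t n d -> b (t n) d')
flat_p = rearrange(perturbed, 'b t n d -> b (t n) d')

k, _ = attn.to_kv(flat).chunk(2, dim=-1)
k_p, _ = attn.to_kv(flat_p).chunk(2, dim=-1)

# which of the 3 * 8 = 24 key positions actually moved when only image 0 changed?
print((k - k_p).abs().amax(dim=-1))
```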