Hello! Thanks for making this.

I was looking through `MaskedCrossAttention`, and I noticed that you generate the key and value using a dense network via
```python
k, v = self.to_kv(media).chunk(2, dim = -1)
```
After this point you calculate the attention matrix, build the masks, etc.
My question is, isn't the point of Flamingo that each text token attends only to the media from the immediately preceding image sequence? If all of the media is passed through a dense network to generate the key and value, doesn't that imply that information from any media input could be present at any position within the final key and value? If so, masking would seem moot, since you would then be masking by location an abstract embedding of all the media inputs, for which location is no longer meaningful. Am I missing something?
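To make the concern concrete, here is a minimal sketch of the flow as I understand it. Everything other than the quoted `to_kv` line is my own illustration, not copied from the repo: the shapes, the `to_q` projection, and the mask construction are all hypothetical.

```python
import torch
import torch.nn as nn

# Hypothetical shapes for illustration only
batch, text_len, num_media, num_latents, dim = 1, 5, 2, 3, 8

to_q  = nn.Linear(dim, dim, bias=False)
to_kv = nn.Linear(dim, dim * 2, bias=False)  # the dense network in question

text  = torch.randn(batch, text_len, dim)
media = torch.randn(batch, num_media, num_latents, dim)

q = to_q(text)
# every media token is projected by the same dense layer before any masking
k, v = to_kv(media).chunk(2, dim=-1)
k = k.reshape(batch, num_media * num_latents, dim)
v = v.reshape(batch, num_media * num_latents, dim)

sim = torch.einsum('b i d, b j d -> b i j', q, k)

# hypothetical mask: each text token may only attend to the latents of
# its immediately preceding image (here tokens 0-2 follow image 0,
# tokens 3-4 follow image 1)
media_index = torch.arange(num_media).repeat_interleave(num_latents)  # (j,)
preceding   = torch.tensor([0, 0, 0, 1, 1])                           # (i,)
mask = preceding[:, None] == media_index[None, :]                     # (i, j)
sim  = sim.masked_fill(~mask, float('-inf'))

attn = sim.softmax(dim=-1)
out  = torch.einsum('b i j, b j d -> b i d', attn, v)
print(out.shape)  # torch.Size([1, 5, 8])
```

In this sketch the mask is applied only after the projection, which is exactly the step I'm asking about.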