lucidrains / perceiver-pytorch

Implementation of Perceiver, General Perception with Iterative Attention, in PyTorch
MIT License

Decoder Attention Module needs a FF network as well in perceiver_io.py script #38

Closed: Hritikbansal closed this 3 years ago

Hritikbansal commented 3 years ago

Hi,

According to the architectural details in the Perceiver IO paper (https://arxiv.org/abs/2107.14795), the decoder attention block consists of a cross-attention block, equation (4), which is already implemented in the perceiver_io.py script (Line 151), followed by a feedforward network, equation (6), which is not present in that script. I am not aware of the repercussions of omitting the FF in the decoder module, but it might be a good idea to include it in the implementation. Something like self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim)) would do the job. Experimentally, the authors found that omitting equation (5) is helpful.
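
A minimal sketch of what this could look like, assuming PreNorm and FeedForward helpers equivalent to the ones already defined in perceiver_io.py (the real module uses a GEGLU feedforward; the plain GELU version below is just for illustration):

```python
import torch
from torch import nn

class PreNorm(nn.Module):
    # layer norm applied before the wrapped function, as in perceiver_io.py
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    # simple MLP stand-in for the repo's feedforward block
    def __init__(self, dim, mult=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        return self.net(x)

queries_dim = 32

# in PerceiverIO.__init__, after the decoder cross-attention:
decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim))

# in PerceiverIO.forward, after cross-attending from the queries to the latents
# (the tensor below is a stand-in for that cross-attention output):
queries = torch.randn(1, 128, queries_dim)
out = decoder_ff(queries) + queries   # residual form, per equation (6) of the paper
print(out.shape)                      # torch.Size([1, 128, 32])
```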

lucidrains commented 3 years ago

@Hritikbansal indeed! I missed this detail, and I've added it into the latest version :) thanks!

Hritikbansal commented 3 years ago

Hi @lucidrains, thanks for taking note of this. I looked at the latest change, and I believe Line 182 should be latents = self.decoder_ff(latents) + latents instead of latents = self.decoder_ff(latents), as per equation (6). Thanks a lot for your work on this implementation! It is really helpful :)

lucidrains commented 3 years ago

@Hritikbansal No problem! I actually reverted it for now, because I am reading their public implementation (https://github.com/deepmind/deepmind-research/blob/master/perceiver/perceiver.py#L620) and I don't see this feedforward.

lucidrains commented 3 years ago

@Hritikbansal not sure what's going on, so I made it optional with a decoder_ff flag! Hope that works for you!
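
For reference, a hedged usage sketch of the optional flag: this assumes decoder_ff is exposed on the PerceiverIO constructor alongside the hyperparameters shown in the repository README (the values here are illustrative, not prescribed):

```python
import torch
from perceiver_pytorch import PerceiverIO

model = PerceiverIO(
    dim = 32,            # dimension of the input sequence
    queries_dim = 32,    # dimension of the decoder queries
    logits_dim = 100,    # final output dimension per query
    depth = 6,
    num_latents = 256,
    latent_dim = 512,
    decoder_ff = True    # enable the feedforward after the decoder cross-attention
)

seq = torch.randn(1, 2048, 32)
queries = torch.randn(1, 128, 32)

logits = model(seq, queries = queries)  # expected shape: (1, 128, 100)
```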