lucidrains / recurrent-interface-network-pytorch

Implementation of Recurrent Interface Network (RIN), for highly efficient generation of images and video without cascading networks, in PyTorch
MIT License

Lacks a `ff` layer? #3

Closed: CiaoHe closed this issue 1 year ago

CiaoHe commented 1 year ago

Hi Phil! Comparing this line against Algorithm 3, Line 4 of the paper, I found a small inconsistency:

https://github.com/lucidrains/recurrent-interface-network-pytorch/blob/19862012f9f685d7a9dd6c2f88609b34f015dbf2/rin_pytorch/rin_pytorch.py#L310

Maybe it needs a feedforward after the cross-attention:

latents = self.latents_attend_to_patches(latents, patches, time = t) + latents 
latents = self.latents_cross_attn_ff(latents, time=t) + latents

Just curious 😁
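The proposed fix follows a residual pattern: the latents are first updated by cross-attending to the patches, then updated by a feedforward, and each update is added back to its input. Below is a minimal PyTorch sketch under stated assumptions. `LatentRead` and its parameters are hypothetical names. `nn.MultiheadAttention` stands in for the repository's own attention modules. Pre-norm placement is assumed, and the `time = t` conditioning is omitted. The `self.ff` sublayer plays the role of the proposed `latents_cross_attn_ff`.

```python
import torch
from torch import nn


class LatentRead(nn.Module):
    """Hypothetical 'read' step: latents attend to patches, then a feedforward.

    Both sublayers are pre-norm and residual, mirroring:
        latents = attn(latents, patches) + latents
        latents = ff(latents) + latents
    Time conditioning from the original repo is intentionally left out.
    """

    def __init__(self, dim: int, heads: int = 4, ff_mult: int = 4):
        super().__init__()
        self.norm_latents = nn.LayerNorm(dim)
        self.norm_patches = nn.LayerNorm(dim)
        self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        # Stand-in for the proposed `latents_cross_attn_ff`.
        self.ff = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * ff_mult),
            nn.GELU(),
            nn.Linear(dim * ff_mult, dim),
        )

    def forward(self, latents: torch.Tensor, patches: torch.Tensor) -> torch.Tensor:
        # Cross-attention: queries come from latents, keys/values from patches.
        q = self.norm_latents(latents)
        kv = self.norm_patches(patches)
        attn_out, _ = self.cross_attn(q, kv, kv, need_weights=False)
        latents = attn_out + latents
        # The feedforward that the issue points out was missing.
        latents = self.ff(latents) + latents
        return latents


torch.manual_seed(0)
read = LatentRead(dim=64)
latents = torch.randn(2, 16, 64)   # (batch, num_latents, dim)
patches = torch.randn(2, 256, 64)  # (batch, num_patches, dim)
out = read(latents, patches)
print(out.shape)  # torch.Size([2, 16, 64])
```

The latents keep their shape because the patches only contribute keys and values. The number of patches can differ from the number of latents.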

lucidrains commented 1 year ago

@CiaoHe Hi He Cao! Yes, indeed, all attention should be followed by feedforwards. Thank you for catching this!
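The rule "all attention should be followed by feedforwards" can be applied to a whole block. The sketch below does this and is not the repository's code. Its read → compute → write ordering follows the general RIN design: latents read from patches, latents self-attend, then patches read back from latents. Several details are assumptions: the class names (`AttnFF`, `RINBlockSketch`), the use of `nn.MultiheadAttention`, the pre-norm placement, and the absence of time conditioning.

```python
import torch
from torch import nn


class AttnFF(nn.Module):
    """Hypothetical unit: pre-norm attention then pre-norm feedforward, both residual.

    With context=None this is self-attention; otherwise queries come from x
    and keys/values come from context (cross-attention).
    """

    def __init__(self, dim: int, heads: int = 4, ff_mult: int = 4):
        super().__init__()
        self.norm_x = nn.LayerNorm(dim)
        self.norm_ctx = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * ff_mult),
            nn.GELU(),
            nn.Linear(dim * ff_mult, dim),
        )

    def forward(self, x, context=None):
        q = self.norm_x(x)
        kv = q if context is None else self.norm_ctx(context)
        out, _ = self.attn(q, kv, kv, need_weights=False)
        x = out + x
        # Every attention is followed by a feedforward.
        return self.ff(x) + x


class RINBlockSketch(nn.Module):
    """Hypothetical RIN-style block: read, compute, write."""

    def __init__(self, dim: int, depth: int = 2, heads: int = 4):
        super().__init__()
        self.read = AttnFF(dim, heads)  # latents attend to patches
        self.compute = nn.ModuleList(
            [AttnFF(dim, heads) for _ in range(depth)]  # latents self-attend
        )
        self.write = AttnFF(dim, heads)  # patches attend to latents

    def forward(self, patches, latents):
        latents = self.read(latents, context=patches)
        for layer in self.compute:
            latents = layer(latents)
        patches = self.write(patches, context=latents)
        return patches, latents


block = RINBlockSketch(dim=64)
patches, latents = block(torch.randn(2, 256, 64), torch.randn(2, 16, 64))
print(patches.shape, latents.shape)  # torch.Size([2, 256, 64]) torch.Size([2, 16, 64])
```

Bundling attention and feedforward into one unit makes it harder to forget the feedforward at any of the three stages.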