lucidrains / perceiver-pytorch

Implementation of Perceiver, General Perception with Iterative Attention, in Pytorch
MIT License

Weight sharing not consistent with paper #67

Open gshaikov opened 1 year ago

gshaikov commented 1 year ago

Hi Phil,

I'd like to confirm the reason behind this design choice: https://github.com/lucidrains/perceiver-pytorch/blob/c3d505a997a6e3521e83d7d2bf57cb8b62e3fbd6/perceiver_pytorch/perceiver_pytorch.py#L194-L210

In the paper, the authors say they tie the weights of all latent transformer blocks. In this implementation, however, the latent transformer in the first layer does not share weights with the rest.

(screenshot of the relevant passage from the paper)
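
For context, here is a minimal, self-contained sketch of how the caching at the linked lines plays out (simplified: a plain nn.Linear stands in for the real attention block, and the variable names mirror the repo). Because should_cache is `i > 0 and weight_tie_layers`, the first depth always builds fresh modules:

    from functools import wraps
    import torch.nn as nn

    def cache_fn(f):
        # memoize the module constructor per `key`, but only when _cache is True
        cache = dict()
        @wraps(f)
        def cached_fn(*args, _cache = True, key = None, **kwargs):
            if not _cache:
                return f(*args, **kwargs)          # fresh module, nothing shared
            if key not in cache:
                cache[key] = f(*args, **kwargs)
            return cache[key]                      # same module object reused
        return cached_fn

    # stand-in for the real latent attention block
    get_latent_attn = cache_fn(lambda: nn.Linear(8, 8))

    depth, weight_tie_layers, self_per_cross_attn = 3, True, 2
    layers = []
    for i in range(depth):
        should_cache = i > 0 and weight_tie_layers     # first depth is never cached
        layers.append([get_latent_attn(_cache = should_cache, key = b)
                       for b in range(self_per_cross_attn)])

    print(layers[0][0] is layers[1][0])   # False: depth 0 has its own weights
    print(layers[1][0] is layers[2][0])   # True: depths 1+ share weights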

It should probably be:

            for block_ind in range(self_per_cross_attn):
                self_attns.append(nn.ModuleList([
                    # always cache, so every depth (including the first) reuses the same block
                    get_latent_attn(_cache = True, key = block_ind),
                    get_latent_ff(_cache = True, key = block_ind)
                ]))
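
If that change is made, the effect could be sanity-checked with something like the sketch below. It is only illustrative: the constructor arguments follow the README example, and the indexing into model.layers assumes each depth is stored as [cross_attn, cross_ff, self_attns], as at the linked commit.

    from perceiver_pytorch import Perceiver

    model = Perceiver(
        input_channels = 3,
        input_axis = 2,
        num_freq_bands = 6,
        max_freq = 10.,
        depth = 4,
        num_latents = 32,
        latent_dim = 64,
        num_classes = 10,
        weight_tie_layers = True,
        self_per_cross_attn = 2,
    )

    # self_attns of the first depth
    first_block = model.layers[0][2]

    for layer in model.layers[1:]:
        for (attn, ff), (attn0, ff0) in zip(layer[2], first_block):
            # with the proposed change these should all print True;
            # with the current code the first depth holds its own copies, so they print False
            print(attn is attn0, ff is ff0)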

What do you think?