lucidrains / perceiver-pytorch

Implementation of Perceiver, General Perception with Iterative Attention, in Pytorch

Weight sharing implementation #58

Closed: yuanmao closed this issue 2 years ago

yuanmao commented 2 years ago

Thanks again for the great work. I have two questions related to the weight sharing implementation:

  1. In the original Perceiver, there are n blocks, each with m self-attention layers, and the i-th self-attention layer of each block shares weights across blocks (https://github.com/deepmind/deepmind-research/blob/826ff89f21e5143dc68ff7cb33f01cc6e237844d/perceiver/perceiver.py#L392). How does this work in your implementation (https://github.com/lucidrains/perceiver-pytorch/blob/2d59df42ebb0b7538af77d584f5ae5b50759618b/perceiver_pytorch/perceiver_pytorch.py#L193) from the second block onward? For example, does the 2nd self-attn layer in the 2nd block share weights with the 2nd self-attn layer in the 1st block, or does it also share weights with the 1st self-attn layer in the 2nd block?
  2. In the original implementation, weight sharing is achieved by repeatedly calling the same module (https://github.com/deepmind/deepmind-research/blob/826ff89f21e5143dc68ff7cb33f01cc6e237844d/perceiver/perceiver.py#L470), whereas here cache_fn is used instead (see the sketch after this list). I'm not sure the two are equivalent, and if they are, do you see any performance benefit from the cache_fn approach?
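
For context on question 2: cache_fn memoizes the module-building function, so every cached call returns the same nn.Module instance and therefore the same parameters, which makes it functionally equivalent to calling one pre-built module repeatedly. A minimal sketch of the idea (simplified and illustrative, not the library's exact code):

```python
import torch.nn as nn

def cache_fn(f):
    # Memoize a module factory: build the module on the first call and
    # return the same instance (hence the same parameters) afterwards.
    cache = None
    def cached_fn(*args, _cache = True, **kwargs):
        nonlocal cache
        if not _cache:
            return f(*args, **kwargs)
        if cache is None:
            cache = f(*args, **kwargs)
        return cache
    return cached_fn

# Both calls return the *same* nn.Linear object, so any block that uses it
# shares its weights, just like calling a single pre-built module repeatedly.
get_ff = cache_fn(lambda: nn.Linear(64, 64))
assert get_ff() is get_ff()
```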
lucidrains commented 2 years ago

@yuanmao Hi there! You caught a bug; my apologies if this caused any inconvenience for your research :cry: I have fixed it in 0.8.1: https://github.com/lucidrains/perceiver-pytorch/releases/tag/0.8.1
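
For readers following along, the sharing pattern described in question 1 (the i-th self-attention layer of every weight-tied block reusing the i-th layer of the first tied block) can be expressed by keying the cached factory on the layer's position within a block. The sketch below, with a hypothetical get_latent_attn factory, is a simplified illustration of that pattern, not a verbatim copy of the 0.8.1 code:

```python
import torch.nn as nn

def cache_fn(f):
    # Keyed memoization: one cached module per key, so layers at the same
    # position in different blocks can be tied without collapsing all
    # layers of a block into a single module.
    cache = dict()
    def cached_fn(*args, _cache = True, key = None, **kwargs):
        if not _cache:
            return f(*args, **kwargs)
        if key not in cache:
            cache[key] = f(*args, **kwargs)
        return cache[key]
    return cached_fn

# Hypothetical factory for a latent self-attention layer.
get_latent_attn = cache_fn(lambda: nn.MultiheadAttention(embed_dim = 64, num_heads = 4))

depth, self_per_cross_attn, weight_tie_layers = 3, 2, True
blocks = nn.ModuleList([])
for i in range(depth):
    # The first block builds fresh layers; later blocks reuse them when tying.
    should_cache = i > 0 and weight_tie_layers
    blocks.append(nn.ModuleList([
        get_latent_attn(_cache = should_cache, key = block_ind)
        for block_ind in range(self_per_cross_attn)
    ]))

# With tying on, the 2nd self-attn layer of the 3rd block is the same module
# as the 2nd self-attn layer of the 2nd block, but it stays distinct from the
# 1st self-attn layer of its own block.
assert blocks[2][1] is blocks[1][1]
assert blocks[2][1] is not blocks[2][0]
```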