lucidrains / perceiver-ar-pytorch

Implementation of Perceiver AR, DeepMind's new long-context attention network based on the Perceiver architecture, in PyTorch
MIT License
86 stars, 4 forks

Cross attention should be over the whole seq and smaller seq #5

Closed: kashif closed this issue 1 year ago

kashif commented 1 year ago

In your code you split the sequence into a prefix and a smaller window, and calculate the cross-attention with respect to the prefix...

https://github.com/lucidrains/perceiver-ar-pytorch/blob/685d77d152c55ef7210336566b952de7da631f68/perceiver_ar_pytorch/perceiver_ar_pytorch.py#L284

However, in the diagram of the method, the whole sequence is used for the V and K... Can you kindly confirm?

Thank you!

lucidrains commented 1 year ago

@kashif yup, well aware!

I reattach the keys/values of the smaller (window) sequence to the prefix being cross-attended to, which is equivalent to attending over the entire sequence.

https://github.com/lucidrains/perceiver-ar-pytorch/blob/685d77d152c55ef7210336566b952de7da631f68/perceiver_ar_pytorch/perceiver_ar_pytorch.py#L167
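A minimal NumPy sketch of why reattaching the window's keys/values to the prefix gives the same result as cross-attending over the whole sequence. This is not the repository's code: it is single-head, omits the normalization, rotary embeddings and projection layers the repo uses, and the function names (`prefix_cross_attention`, `full_causal_last_window`) are hypothetical. It assumes, as in Perceiver AR, that the window (query) positions are causally masked among themselves.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attend(q, k, v, mask):
    # q: (n_q, d), k/v: (n_kv, d), mask: (n_q, n_kv) boolean, True = may attend
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = np.where(mask, scores, -np.inf)
    return softmax(scores) @ v

def full_causal_last_window(x, wq, wk, wv, window):
    # Reference (the paper's diagram): K/V from the whole sequence, causal mask,
    # keeping only the outputs for the last `window` positions.
    n = x.shape[0]
    q, k, v = x @ wq, x @ wk, x @ wv
    mask = np.tril(np.ones((n, n), dtype=bool))
    return attend(q, k, v, mask)[-window:]

def prefix_cross_attention(x, wq, wk, wv, window):
    # Hypothetical sketch of the split: queries come only from the window.
    prefix, latents = x[:-window], x[-window:]
    q = latents @ wq
    # Reattach the window to the prefix before projecting to keys/values.
    kv_input = np.concatenate([prefix, latents], axis=0)
    k, v = kv_input @ wk, kv_input @ wv
    # Every query sees the whole prefix; within the window attention is causal.
    mask = np.concatenate(
        [np.ones((window, prefix.shape[0]), dtype=bool),
         np.tril(np.ones((window, window), dtype=bool))],
        axis=1,
    )
    return attend(q, k, v, mask)

rng = np.random.default_rng(0)
seq_len, dim, window = 10, 8, 4
x = rng.standard_normal((seq_len, dim))
wq, wk, wv = (rng.standard_normal((dim, dim)) for _ in range(3))

out_prefix = prefix_cross_attention(x, wq, wk, wv, window)
out_full = full_causal_last_window(x, wq, wk, wv, window)
print(np.allclose(out_prefix, out_full))  # True
```

Since concatenating the prefix and the window just rebuilds the original sequence, and the block mask (all prefix positions visible, causal within the window) equals the last `window` rows of a full causal mask, the two outputs agree. Dropping the reattached window from the keys/values would break that equivalence.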

kashif commented 1 year ago

Right! I missed that, thanks!