Rick-McCoy / Reformer-pytorch

Implements Reformer: The Efficient Transformer in pytorch.
MIT License

look_back function attends last bucket to the first bucket #3

Closed: seewoo5 closed this issue 4 years ago.

seewoo5 commented 4 years ago

It seems that the look_back function defined in utils.py gives a wrong result for the first bucket: the first bucket attends to the last bucket. For example, if the sequence length is 6 and the bucket size is 2 (so there are 3 buckets), then for an input (q_0, q_1, q_2, q_3, q_4, q_5) the result is something like ((q_4, q_5, q_0, q_1), (q_0, q_1, q_2, q_3), (q_2, q_3, q_4, q_5)). Although I don't fully understand Google's official Reformer implementation in Trax, this seems to be handled there somehow (by masking out q_4 and q_5?).
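To make the reported behaviour concrete, here is a minimal sketch that reproduces the example above. It is not the repo's actual utils.py code: the helpers `split_into_chunks` and `look_back` below are hypothetical, and plain Python lists stand in for PyTorch tensors. The idea is that each chunk is concatenated with the chunk before it, and Python's negative indexing (`chunks[-1]`) makes the first chunk wrap around to the last one.

```python
def split_into_chunks(seq, chunk_size):
    # Hypothetical helper; assumes len(seq) is divisible by chunk_size.
    return [seq[i:i + chunk_size] for i in range(0, len(seq), chunk_size)]


def look_back(chunks):
    # Hypothetical look-back: for each chunk i, prepend chunk i-1.
    # For i == 0, chunks[-1] is the LAST chunk, so the first bucket
    # ends up attending to the last bucket (the wrap-around in the issue).
    return [chunks[i - 1] + chunks[i] for i in range(len(chunks))]


queries = [f"q_{i}" for i in range(6)]   # sequence length 6
chunks = split_into_chunks(queries, 2)   # bucket size 2 -> 3 buckets
result = look_back(chunks)

for row in result:
    print(tuple(row))
```

Running it prints the three tuples from the issue, with `('q_4', 'q_5', 'q_0', 'q_1')` as the first row. The wrap-around is a side effect of the indexing, not a separate rule.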

seewoo5 commented 4 years ago

Sorry, this was a dumb question: those queries will be masked out, since q_0, q_1 and q_4, q_5 may have different hashes.
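The resolution can be illustrated with a hash-based mask: a query only attends to a key in the same bucket, so keys pulled in by the wrap-around are dropped whenever their bucket ids differ. This is a sketch under stated assumptions, not the repo's code: the bucket ids in `buckets` are made up for illustration, positions are assumed to be already sorted by bucket, and `hash_mask` is a hypothetical helper using plain lists instead of tensors.

```python
def split_into_chunks(seq, chunk_size):
    # Hypothetical helper; assumes len(seq) is divisible by chunk_size.
    return [seq[i:i + chunk_size] for i in range(0, len(seq), chunk_size)]


def look_back(chunks):
    # Each chunk is preceded by the previous one; chunk 0 wraps to the last.
    return [chunks[i - 1] + chunks[i] for i in range(len(chunks))]


def hash_mask(query_chunk, key_chunk, buckets):
    # True where the query may attend to the key (same bucket id),
    # False where the pair is masked out.
    return [[buckets[q] == buckets[k] for k in key_chunk] for q in query_chunk]


# Hypothetical bucket ids for positions 0..5 (assumed sorted by bucket).
buckets = [0, 0, 1, 1, 2, 2]
positions = list(range(6))

query_chunks = split_into_chunks(positions, 2)
key_chunks = look_back(query_chunks)

# First chunk: queries q_0, q_1 against keys q_4, q_5, q_0, q_1.
mask0 = hash_mask(query_chunks[0], key_chunks[0], buckets)
for row in mask0:
    print(row)
```

With these assumed bucket ids, the two columns for q_4 and q_5 are all `False`, so the wrapped-around keys contribute nothing. If the first and last chunks happened to share a bucket, those pairs would not be masked; that is why the answer says the hashes "may" differ.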