It seems that the `look_back` function defined in `utils.py` gives the wrong result for the first bucket: the first bucket attends to the last bucket. For example, if the sequence length is 6 and the bucket size is 2 (so there are 3 buckets), then the input (q_0, q_1, q_2, q_3, q_4, q_5) becomes ((q_4, q_5, q_0, q_1), (q_0, q_1, q_2, q_3), (q_2, q_3, q_4, q_5)). I can't fully follow Google's official Reformer implementation in Trax, but it seems to handle this case somehow. (Perhaps by masking out q_4 and q_5?)
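To make the behavior concrete, here is a minimal pure-Python sketch that reproduces the wrap-around described above on plain lists. It is not the actual code from `utils.py`. The helpers `split_buckets` and `look_back_mask` are hypothetical names used only for illustration, and the sketch assumes the sequence length is a multiple of the bucket size. The mask shows one possible fix, the "mask out q_4 and q_5" idea from the question, and not necessarily what Trax does.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_buckets(seq: Sequence[T], bucket_size: int) -> List[List[T]]:
    """Hypothetical helper: split a sequence into consecutive buckets."""
    # Assumption: the sequence length is an exact multiple of bucket_size.
    assert len(seq) % bucket_size == 0, "length must be a multiple of bucket_size"
    return [list(seq[i:i + bucket_size]) for i in range(0, len(seq), bucket_size)]


def look_back(buckets: List[List[T]]) -> List[List[T]]:
    """Prepend the previous bucket to each bucket (mimics the reported behavior)."""
    # Rotating right by one bucket makes bucket 0 pick up the LAST bucket,
    # which is the wrap-around this issue is about.
    previous = [buckets[-1]] + buckets[:-1]
    return [p + b for p, b in zip(previous, buckets)]


def look_back_mask(num_buckets: int, bucket_size: int) -> List[List[bool]]:
    """Hypothetical fix: True marks a key position that should be masked out."""
    # Only the first half of bucket 0 (the wrapped-around last bucket) is masked;
    # every other bucket keeps both its previous bucket and itself.
    return [[i == 0] * bucket_size + [False] * bucket_size for i in range(num_buckets)]


if __name__ == "__main__":
    buckets = split_buckets(["q0", "q1", "q2", "q3", "q4", "q5"], 2)
    print(look_back(buckets))
    # → [['q4', 'q5', 'q0', 'q1'], ['q0', 'q1', 'q2', 'q3'], ['q2', 'q3', 'q4', 'q5']]
    print(look_back_mask(len(buckets), 2)[0])
    # → [True, True, False, False]
```

In a tensor implementation, the same boolean mask would be applied to the attention logits before the softmax, for example by filling the masked key positions with a large negative value, so the first bucket no longer attends to q_4 and q_5.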