@L-Hugh You are absolutely right about points 1 and 2! I have made the changes and they are live. I initially transcribed that code from the official Trax repository, and that was the way they had it.
The custom function exists to "connect" the gradients across the discontinuity introduced when unsorting the bucketed logits returned by the LSH sort.
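A minimal sketch of that idea, assuming a custom `torch.autograd.Function` that gathers values back into their pre-sort order in the forward pass and routes gradients through the inverse permutation in the backward pass (the name `UnsortWithGrad` and the exact shapes are illustrative assumptions, not the repo's actual API):

```python
import torch
from torch.autograd import Function

class UnsortWithGrad(Function):
    @staticmethod
    def forward(ctx, sorted_vals, undo_sort):
        ctx.save_for_backward(undo_sort)
        # restore the original (pre-sort) ordering of the bucketed values
        return sorted_vals.gather(-1, undo_sort)

    @staticmethod
    def backward(ctx, grad_out):
        (undo_sort,) = ctx.saved_tensors
        # argsort of undo_sort inverts the permutation, sending each
        # gradient entry back to its position in the sorted tensor
        _, redo_sort = undo_sort.sort(dim=-1)
        return grad_out.gather(-1, redo_sort), None

# usage: gradients flow back to the sorted tensor despite the unsort
sorted_logits = torch.randn(2, 8, requires_grad=True)
undo_sort = torch.stack([torch.randperm(8) for _ in range(2)])
logits = UnsortWithGrad.apply(sorted_logits, undo_sort)
logits.sum().backward()
```

In plain PyTorch a bare `gather` is already differentiable, so an explicit `Function` like this mainly matters when parts of the sorted computation are detached, which appears to be why the gradients need to be "connected" explicitly here.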
@L-Hugh For point 3, on closer examination, it turns out the code I transcribed was there for TPU-specific optimizations (https://github.com/google/trax/blob/master/trax/layers/research/efficient_attention.py#L1264). Combined with your suggestion for point 2, I removed the class altogether!
@L-Hugh Thank you for pointing this out!
@lucidrains Thanks for your reply!
Thanks for your implementation! I have some questions:
1. Can https://github.com/lucidrains/reformer-pytorch/blob/9851ccb2017ff4f64a5733045775861b788f69dc/reformer_pytorch/reformer_pytorch.py#L276 be replaced by `_, undo_sort = sticker.sort(dim=-1)`? I think they are equivalent (see the sketch after this list).
2. Can https://github.com/lucidrains/reformer-pytorch/blob/9851ccb2017ff4f64a5733045775861b788f69dc/reformer_pytorch/reformer_pytorch.py#L403 be replaced by `logits = slogits.gather(-1, undo_sort)`?
3. What's the reason for using a custom class? https://github.com/lucidrains/reformer-pytorch/blob/9851ccb2017ff4f64a5733045775861b788f69dc/reformer_pytorch/reformer_pytorch.py#L397
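A quick self-contained check of the equivalence suspected in point 1, assuming `sticker` holds the permutation produced by the LSH sort (the tensor names are stand-ins for illustration, not necessarily the repo's exact variables):

```python
import torch

sticker = torch.randperm(8).unsqueeze(0)  # stand-in permutation from the LSH sort
_, undo_sort = sticker.sort(dim=-1)       # sorting a permutation yields its inverse

vals = torch.randn(1, 8)
sorted_vals = vals.gather(-1, sticker)        # apply the permutation
restored = sorted_vals.gather(-1, undo_sort)  # undo it
assert torch.allclose(restored, vals)
```

If this holds, the index tensor from the `sort_key_val`-style bookkeeping at L276 and the indices returned by `sticker.sort(dim=-1)` should agree.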
Thanks again for your work! Please correct me if I'm wrong.