lucidrains / reformer-pytorch

Reformer, the efficient Transformer, in Pytorch
MIT License

Some questions about class LSHAttention(nn.Module) #109

Closed. L-Hugh closed this issue 4 years ago.

L-Hugh commented 4 years ago

Thanks for your implementation! I have some questions:

  1. Can https://github.com/lucidrains/reformer-pytorch/blob/9851ccb2017ff4f64a5733045775861b788f69dc/reformer_pytorch/reformer_pytorch.py#L276 be replaced by `_, undo_sort = sticker.sort(dim=-1)`? I think the two are equivalent.

  2. Similarly, can https://github.com/lucidrains/reformer-pytorch/blob/9851ccb2017ff4f64a5733045775861b788f69dc/reformer_pytorch/reformer_pytorch.py#L403 be replaced by `logits = slogits.gather(1, undo_sort)`?

  3. What is the reason for using a custom autograd class here? https://github.com/lucidrains/reformer-pytorch/blob/9851ccb2017ff4f64a5733045775861b788f69dc/reformer_pytorch/reformer_pytorch.py#L397

Thanks again for your work! Please correct me if I'm wrong. (A small sketch of what I mean in points 1 and 2 follows below.)
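
Roughly what I have in mind, as a small sketch (assuming `sort_key_val` behaves like the helper in reformer_pytorch.py, i.e. it sorts the first tensor and reorders the second with the same indices; names follow the file, shapes are simplified):

```python
import torch

# assumed behaviour of the repo's sort_key_val helper: sort the first tensor
# and reorder the second with the same indices
def sort_key_val(t1, t2, dim=-1):
    values, indices = t1.sort(dim=dim)
    return values, t2.gather(dim, indices)

batch, seqlen = 2, 8
buckets_and_t = torch.randint(0, 100, (batch, seqlen))   # stand-in for bucket ids
ticker = torch.arange(seqlen).repeat(batch, 1)

# the sort used for bucketing
sbuckets_and_t, sticker = sort_key_val(buckets_and_t, ticker, dim=-1)

# point 1: both ways of building the inverse permutation agree
_, undo_sort_a = sort_key_val(sticker, ticker, dim=-1)    # the original construction
_, undo_sort_b = sticker.sort(dim=-1)                     # suggested replacement
assert torch.equal(undo_sort_a, undo_sort_b)

# point 2: gathering with undo_sort puts sorted values back in pre-sort order,
# so the unsorting of the logits can become a single gather
assert torch.equal(sbuckets_and_t.gather(-1, undo_sort_b), buckets_and_t)
slogits = torch.randn(batch, seqlen)                      # logits in sorted order
logits = slogits.gather(1, undo_sort_b)                   # suggested one-liner
```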

lucidrains commented 4 years ago

@L-Hugh you are absolutely right about points 1 and 2! I have made the changes and they are live! I initially transcribed that code from the official Trax repository, and that is how they had it there.

The custom function is there to "connect" the gradients across the discontinuity introduced by unsorting the bucketed logits coming out of the LSH.
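
Roughly, a minimal sketch of that idea (not the exact class from the repo; `UnsortWithGrad` and the shapes are just for illustration): the forward pass unsorts with a gather, and the backward pass scatters the incoming gradient back into sorted order so it reaches the bucketed logits.

```python
import torch

class UnsortWithGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, slogits, undo_sort):
        ctx.save_for_backward(undo_sort)
        # unsort: put the bucketed logits back in their original order
        return slogits.gather(-1, undo_sort)

    @staticmethod
    def backward(ctx, grad_out):
        undo_sort, = ctx.saved_tensors
        # route the gradient back into sorted order (the inverse of the gather above)
        grad_slogits = torch.zeros_like(grad_out).scatter_(-1, undo_sort, grad_out)
        return grad_slogits, None
```

Usage would be something like `logits = UnsortWithGrad.apply(slogits, undo_sort)`.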

lucidrains commented 4 years ago

@L-Hugh For 3, on closer examination, it turns out the code I transcribed was written that way for TPU-specific optimizations: https://github.com/google/trax/blob/master/trax/layers/research/efficient_attention.py#L1264. Combined with your suggestion in point 2, I was able to remove the class altogether!
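
As a quick sanity check (again just a sketch, not code from the repo): `torch.gather` is itself differentiable, so the single gather from point 2 keeps the gradients connected without any custom `Function`.

```python
import torch

slogits = torch.randn(2, 8, requires_grad=True)
undo_sort = torch.stack([torch.randperm(8) for _ in range(2)])

# plain gather already routes gradients through the unsort
logits = slogits.gather(1, undo_sort)
logits.sum().backward()

# each sorted position is used exactly once, so every gradient entry is 1
assert torch.allclose(slogits.grad, torch.ones_like(slogits))
```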

lucidrains commented 4 years ago

@L-Hugh Thank you for pointing this out!

L-Hugh commented 4 years ago

@lucidrains Thanks for your reply!