harvardnlp / pytorch-struct

Fast, general, and tested differentiable structured prediction in PyTorch
http://harvardnlp.github.io/pytorch-struct
MIT License
1.11k stars, 93 forks

Possible indexing error #59

kmkurn closed 4 years ago

kmkurn commented 4 years ago

Hi, in the line below, shouldn't the arange go to N+1? Otherwise the last diagonal element (new_logits[:, -1, -1]) isn't set to -1e9.
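To make the off-by-one concrete, here is a minimal sketch. It is not the library's code: it only assumes a padded tensor of shape (batch, N+1, N+1) whose diagonal entries 1..N are meant to be masked. The helper name `masked_diagonal` and its `stop` parameter are hypothetical.

```python
import torch

def masked_diagonal(N, stop):
    # Hypothetical helper (not part of pytorch-struct): build a zero tensor of
    # shape (1, N+1, N+1) and set diagonal entries 1..stop-1 to -1e9.
    x = torch.zeros(1, N + 1, N + 1)
    idx = torch.arange(1, stop)
    x[:, idx, idx] = -1e9
    return x

N = 3
short = masked_diagonal(N, N)      # arange(1, N) covers indices 1..N-1 only
full = masked_diagonal(N, N + 1)   # arange(1, N+1) also covers index N

print(short[0, -1, -1].item())  # last diagonal entry is left untouched
print(full[0, -1, -1].item())   # last diagonal entry is masked
```

Because `torch.arange(1, N)` excludes its upper bound, the entry at `[:, N, N]` of the padded tensor is never written in the first call.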

https://github.com/harvardnlp/pytorch-struct/blob/2d27cad78b949f2a0add9826c1efc5b8b8190d36/torch_struct/deptree.py#L16

srush commented 4 years ago

I think you are correct that this is a bug. I don't think it changes the outcome, though? If it does, I am surprised the unit test does not fail. This code also needs to be documented better.

kmkurn commented 4 years ago

I think it depends on the case. I was converting the output of DependencyCRF.argmax, which had the last token connected to the root token, and trying to get the head indices by taking an argmax over the head dim. The result had the last token connected to itself, because both row 0 and row N-1 were 1. This is probably not the intended usage of this function, though.
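The symptom can be reproduced with a small stand-in for the conversion. This is a sketch under stated assumptions, not the function from deptree.py: it assumes arcs are stored as `arcs[b, head, modifier]` with root arcs on the diagonal, and that the conversion pads to (N+1) x (N+1), moves root arcs to row 0, and masks the diagonal. The name `convert_sketch` and the `fix` flag are hypothetical, and row and column numbers below refer to the padded matrix in this sketch.

```python
import torch

def convert_sketch(arcs, fix):
    # Hypothetical stand-in, NOT the library code.
    # arcs: (batch, N, N), arcs[b, h, m] = 1 if h heads m; diagonal = root arcs.
    b, N, _ = arcs.shape
    out = torch.full((b, N + 1, N + 1), -1e9)
    out[:, 1:, 1:] = arcs
    idx = torch.arange(N)
    out[:, 0, 1:] = arcs[:, idx, idx]            # move root arcs to row 0
    diag = torch.arange(1, N + 1 if fix else N)  # fix=False mimics the short range
    out[:, diag, diag] = -1e9                    # mask the old root positions
    return out

# One sentence with 3 tokens: the last token is the root and heads the other two.
arcs = torch.zeros(1, 3, 3)
arcs[0, 2, 2] = 1.0  # root arc on the diagonal
arcs[0, 2, 0] = 1.0
arcs[0, 2, 1] = 1.0

buggy = convert_sketch(arcs, fix=False)
fixed = convert_sketch(arcs, fix=True)

# With the short range, the last column has a 1 in row 0 and in the last row,
# so an argmax over the head dim is ambiguous for the last token.
print(buggy[0, :, 3])
# With the full range, only row 0 remains, and heads can be read off directly.
print(fixed.argmax(dim=1)[0, 1:])
```

In the short-range case the last column is tied between the root row and the self-loop, so which head an argmax returns is not something to rely on.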

srush commented 4 years ago

Great. Can you send me a one-line PR? And, if you are feeling particularly nice, a unit test.

Let me know if there is anything else the library can support.