Closed by kmkurn 4 years ago
I think you are correct that this is a bug. I don't think it changes the outcome though? Or if it does, I am surprised the unit test does not fail. This code also needs to be documented better.
I think it depends on the case. For me, I was converting the output of DependencyCRF.argmax, which had the last token connected to the root token, and trying to get the head indices by doing argmax over the head dim, but the result had the last token connected to itself because both row 0 and row N-1 were 1. Probably not the intended usage of this function though.
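To make the ambiguity concrete, here is a toy plain-Python sketch of the situation described above. It assumes a 0/1 arc matrix in which rows are heads, columns are modifiers, and index 0 is the root; the matrix values and the `candidate_heads` helper are made up for illustration and are not part of pytorch-struct.

```python
# Toy 0/1 arc matrix (assumption): arcs[h][m] == 1 means head h -> modifier m.
# Index 0 is the root; tokens are 1..3. Intended tree: root -> 3, 3 -> 2, 2 -> 1.
arcs = [
    [0, 0, 0, 1],  # root -> token 3
    [0, 0, 0, 0],
    [0, 1, 0, 0],  # token 2 -> token 1
    [0, 0, 1, 1],  # token 3 -> token 2, plus the stale diagonal entry 3 -> 3
]


def candidate_heads(arcs, m):
    """Return every head h whose row has a 1 in column m (hypothetical helper)."""
    return [h for h, row in enumerate(arcs) if row[m] == 1]


# Tokens 1 and 2 have a unique head, but the last token has two candidates:
# the root (row 0) and itself (last row), so an argmax over the head
# dimension has to break a tie and may pick the self-loop.
heads_per_token = {m: candidate_heads(arcs, m) for m in range(1, 4)}
```

With the last diagonal entry masked out, column 3 would contain a single 1 in row 0 and the head would be unambiguous.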
Great. Can you send me a one-line PR? And, if you are feeling particularly nice, a unit test.
Let me know if there is anything else the library can support.
Hi, in the line below, shouldn't the arange go to N+1? Otherwise the last diagonal element (new_logits[:, -1, -1]) isn't set to -1e9.
https://github.com/harvardnlp/pytorch-struct/blob/2d27cad78b949f2a0add9826c1efc5b8b8190d36/torch_struct/deptree.py#L16
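The off-by-one can be illustrated without the library. The following is a minimal plain-Python sketch, not the code from deptree.py: it assumes the converted matrix is (N+1) x (N+1), that the diagonal is masked by indexing both axes with the same range, and that the range starts at 1. `masked_diagonal` is a hypothetical helper, and `range` stands in for `torch.arange`, which likewise excludes its end value.

```python
NEG = -1e9


def masked_diagonal(n, stop):
    """Report which diagonal entries of an (n+1) x (n+1) matrix are masked
    when both axes are indexed with range(1, stop). Hypothetical helper."""
    size = n + 1
    m = [[0.0] * size for _ in range(size)]
    for i in range(1, stop):  # stands in for torch.arange(1, stop)
        m[i][i] = NEG
    return [m[i][i] == NEG for i in range(size)]


N = 3
# Ending the range at N covers indices 1..N-1, so the last entry is missed.
current = masked_diagonal(N, N)
# Ending the range at N + 1 covers indices 1..N, including the last entry.
proposed = masked_diagonal(N, N + 1)
```

For N = 3, `current` leaves the final diagonal position unmasked while `proposed` masks it, which is the difference the question above is pointing at.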