harvardnlp / pytorch-struct

Fast, general, and tested differentiable structured prediction in PyTorch
http://harvardnlp.github.io/pytorch-struct
MIT License

Labeled projective dependency CRF #63

Closed: kmkurn closed this 4 years ago

kmkurn commented 4 years ago

This is work in progress and isn't ready to merge yet.

This seems to work for partition, but argmax and marginals don't return what I expect. Both return tensors of shape (B, N, N); I'd expect them to return (B, N, N, L) tensors instead. Any advice?
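A minimal sketch of the shape question, assuming an arc-factored model in which each arc carries exactly one label out of L. This is not the code in this PR, and every helper name below is hypothetical. It illustrates why partition can come out right while argmax and marginals lose the label axis: the partition only needs arc scores with the labels summed out (logsumexp over L), whereas labeled marginals and a labeled argmax have to put the L axis back.

```python
import torch

# Hypothetical helpers (not part of pytorch-struct). They illustrate how a
# labeled, arc-factored model relates to an unlabeled one over (B, N, N) arcs.

def scores_for_partition(log_potentials: torch.Tensor) -> torch.Tensor:
    """(B, N, N, L) -> (B, N, N): logsumexp over the label axis.

    An unlabeled CRF run on these scores has the same partition as the
    labeled model, so partition can look right even if the label axis
    is dropped downstream.
    """
    return torch.logsumexp(log_potentials, dim=-1)


def scores_for_argmax(log_potentials: torch.Tensor):
    """(B, N, N, L) -> (best score, best label) per arc, each (B, N, N)."""
    best_scores, best_labels = log_potentials.max(dim=-1)
    return best_scores, best_labels


def labeled_marginals(arc_marginals: torch.Tensor,
                      log_potentials: torch.Tensor) -> torch.Tensor:
    """Combine unlabeled arc marginals (B, N, N) with p(label | arc).

    Returns (B, N, N, L); summing over the last axis recovers arc_marginals.
    """
    label_probs = torch.softmax(log_potentials, dim=-1)
    return arc_marginals.unsqueeze(-1) * label_probs


def labeled_argmax(arc_argmax: torch.Tensor,
                   best_labels: torch.Tensor,
                   num_labels: int) -> torch.Tensor:
    """Expand a 0/1 unlabeled tree (B, N, N) into a one-hot (B, N, N, L) tensor."""
    one_hot = torch.nn.functional.one_hot(best_labels, num_labels)
    return arc_argmax.unsqueeze(-1) * one_hot.to(arc_argmax.dtype)


if __name__ == "__main__":
    B, N, L = 2, 4, 3
    log_potentials = torch.randn(B, N, N, L)
    print(scores_for_partition(log_potentials).shape)  # torch.Size([2, 4, 4])
```

Here arc_marginals and arc_argmax stand for the outputs of an unlabeled dependency CRF: the marginals computed on the logsumexp-collapsed scores, the argmax tree computed on the max-collapsed scores. They are inputs to the sketch, not something it computes.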

srush commented 4 years ago

Almost perfect. You just have to update _unconvert as well; it gets called when the marginals are arranged. If you also make it squeeze the last dimension, the old code will keep working.
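A rough sketch of what a labeled "unconvert" step with that squeeze might look like. The function name is hypothetical, and the layouts are assumptions rather than something this thread confirms: an internal chart of shape (B, N+1, N+1, L) with index 0 reserved for the root, and a user-facing (B, N, N, L) layout with root arcs on the diagonal.

```python
import torch


def unconvert_labeled(logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical labeled variant of an "unconvert" step.

    Assumes logits has shape (B, N+1, N+1, L): row 0 holds root -> j arcs,
    and rows/cols 1..N hold head -> dependent arcs. Returns (B, N, N, L)
    with root arcs moved onto the diagonal, or (B, N, N) when L == 1.
    """
    _, n_plus_one, _, num_labels = logits.shape
    n = n_plus_one - 1
    # Copy the head -> dependent block; clone so the input is not modified.
    out = logits[:, 1:, 1:].clone()
    idx = torch.arange(n, device=logits.device)
    # Move root -> j scores (row 0) onto the diagonal, for every label.
    out[:, idx, idx] = logits[:, 0, 1:]
    # Drop a singleton label axis so callers written for the unlabeled
    # (B, N, N) layout keep working.
    if num_labels == 1:
        out = out.squeeze(-1)
    return out
```

Using squeeze(-1) rather than a bare squeeze() matters here: a bare squeeze would also drop the batch or sentence dimension whenever B or N happens to be 1.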

kmkurn commented 4 years ago

I've updated _unconvert as suggested, but both argmax and marginals still have 3 dimensions :-(

UPDATE: I think I've got it.

srush commented 4 years ago

Awesome. Adding