harvardnlp / pytorch-struct

Fast, general, and tested differentiable structured prediction in PyTorch
http://harvardnlp.github.io/pytorch-struct
MIT License

Device mismatch in LinearChainCRF #111

Closed: kmkurn closed this issue 3 years ago

kmkurn commented 3 years ago

The line below causes a device-mismatch error when chart is on the GPU, because init is always created on the CPU:
https://github.com/harvardnlp/pytorch-struct/blob/e51fecc1473925e4c44de135c4a3240fcb20fa40/torch_struct/linearchain.py#L56
I think the fix would be torch.zeros_like(chart).bool().
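Here is a minimal sketch of the problem and the suggested fix. It assumes the original line builds init with a shape-only constructor such as torch.zeros(chart.shape), which always allocates on the default device (normally the CPU). The exact code at linearchain.py#L56 is not reproduced here, and the helper names below are hypothetical.

```python
import torch


def make_init_mask_buggy(chart: torch.Tensor) -> torch.Tensor:
    # Hypothetical reconstruction of the problematic pattern: torch.zeros
    # given only a shape allocates on the default device, ignoring
    # wherever `chart` lives (e.g. CUDA).
    return torch.zeros(chart.shape).bool()


def make_init_mask_fixed(chart: torch.Tensor) -> torch.Tensor:
    # Fix suggested in the issue: zeros_like copies chart's shape and
    # device, then .bool() converts the all-zero tensor to a False mask.
    return torch.zeros_like(chart).bool()


# Use the GPU when available so the mismatch can actually be observed.
device = "cuda" if torch.cuda.is_available() else "cpu"
chart = torch.randn(2, 5, 3, 3, device=device)  # arbitrary example shape

buggy = make_init_mask_buggy(chart)
fixed = make_init_mask_fixed(chart)
print("buggy on chart's device:", buggy.device == chart.device)
print("fixed on chart's device:", fixed.device == chart.device)
```

On a CUDA machine the buggy version reports False and the fixed one True; on a CPU-only machine both report True, so the bug only surfaces in GPU runs. A variant of the fix, torch.zeros_like(chart, dtype=torch.bool), produces the same mask without allocating an intermediate float tensor.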

srush commented 3 years ago

Thanks. We need to get the GPU tests into CI.