Closed kmkurn closed 3 years ago
The line below causes a device mismatch error when `chart` is on the GPU, because `init` is always created on the CPU: https://github.com/harvardnlp/pytorch-struct/blob/e51fecc1473925e4c44de135c4a3240fcb20fa40/torch_struct/linearchain.py#L56 I think the fix would be `torch.zeros_like(chart).bool()`.
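A minimal sketch of the suggested fix, not the library's actual code. The helper names `make_init_mask` and `make_init_mask_direct` and the tensor shape are hypothetical and chosen for illustration only. It shows that a mask built with `torch.zeros_like` takes its device from `chart`, so the two tensors can be combined without a mismatch.

```python
import torch


def make_init_mask(chart: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: all-False boolean mask on the same device as `chart`."""
    # zeros_like copies shape and device from `chart`; .bool() converts the dtype.
    return torch.zeros_like(chart).bool()


def make_init_mask_direct(chart: torch.Tensor) -> torch.Tensor:
    """Equivalent variant that allocates the boolean tensor directly."""
    return torch.zeros_like(chart, dtype=torch.bool)


# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
chart = torch.randn(2, 3, 4, 4, device=device)  # illustrative shape only

init = make_init_mask(chart)
init_direct = make_init_mask_direct(chart)
```

The second variant avoids allocating an intermediate floating-point tensor, which may matter for large charts; both return a mask on `chart.device`.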
Thanks. We need to get the GPU tests into the CI.