kmkurn / pytorch-crf

(Linear-chain) Conditional random field in PyTorch.
https://pytorch-crf.readthedocs.io
MIT License

GPU support #17

Closed JeppeHallgren closed 6 years ago

JeppeHallgren commented 6 years ago

Adds a use_gpu flag to the CRF constructor which, if set to True, places the internal variables on the GPU. This means that viterbi_decode and forward can now be called with data already on the GPU, leading to significant speed gains; see #16.

Tests and linters pass.
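The proposed behavior can be sketched with a toy module (a minimal sketch; the ToyCRF name and transitions parameter are illustrative stand-ins, not pytorch-crf's actual internals):

```python
import torch
import torch.nn as nn

class ToyCRF(nn.Module):
    """Toy stand-in for a CRF layer holding transition parameters."""
    def __init__(self, num_tags: int, use_gpu: bool = False) -> None:
        super().__init__()
        # Place internal parameters on the GPU at construction time when
        # use_gpu=True, mirroring the flag this PR proposes.
        device = torch.device("cuda") if use_gpu else torch.device("cpu")
        self.transitions = nn.Parameter(
            torch.randn(num_tags, num_tags, device=device)
        )

# Only request the GPU when one is actually available.
crf = ToyCRF(5, use_gpu=torch.cuda.is_available())
```

With the parameters already on the GPU, forward and decode calls can then accept GPU tensors directly without device transfers.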

coveralls commented 6 years ago


Coverage decreased (-7.5%) to 92.531% when pulling 7f368bd6d6aad53359e309a45f00b6a46dea597d on JeppeHallgren:feature/gpu-support into 839299625cd2fdae4a9b3d6aa87c230010f7961e on kmkurn:master.

JeppeHallgren commented 6 years ago

@kmkurn not sure how you want to handle this. The GPU test is not executed by the coverage bot, since it doesn't have a GPU, hence the lower reported coverage. We could remove the test, but it seems nice to have GPU tests.

kmkurn commented 6 years ago

Hi, thanks again for the PR!

Users can already move the parameters to the GPU like this:

from torchcrf import CRF

crf = CRF(5)
crf.cuda(0)  # this will move parameters to GPU device 0 recursively

Every PyTorch module (i.e. any class that inherits from torch.nn.Module) has that .cuda() method. So, I think the use_gpu flag is not needed.
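The point above can be checked with a toy module (a minimal sketch using plain torch, not pytorch-crf itself): any parameter registered on an nn.Module is moved recursively by .cuda() or .to(device).

```python
import torch
import torch.nn as nn

class ToyCRF(nn.Module):
    """Toy stand-in for a CRF layer holding transition parameters."""
    def __init__(self, num_tags: int) -> None:
        super().__init__()
        self.transitions = nn.Parameter(torch.randn(num_tags, num_tags))

# .to() (like .cuda()) moves all registered parameters recursively,
# so no use_gpu constructor flag is needed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
crf = ToyCRF(5).to(device)
```

The same pattern works for any submodules the CRF might contain, since nn.Module applies the move to the whole parameter tree.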

JeppeHallgren commented 6 years ago

Makes sense, thanks.