Closed — JeppeHallgren closed this 6 years ago
@kmkurn not sure how you want to handle this. The GPU test is not executed by the coverage bot, since the bot has no GPU, which is why it reports lower coverage. We could remove the test, but it seems nice to keep GPU tests.
Hi, thanks again for the PR!
To move the parameters to GPU, users can already do so by:
from torchcrf import CRF
crf = CRF(5)
crf.cuda(0) # this will move parameters to GPU device 0 recursively
Every PyTorch module (i.e. any class that inherits from torch.nn.Module) already has that .cuda() method. So, I think the use_gpu flag is not needed.
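For context, the recursive move that .cuda() performs can be sketched without a real GPU. The Param/Module classes below are simplified, hypothetical stand-ins for torch.nn internals, not the actual PyTorch implementation; they only illustrate why a subclass like CRF gets GPU placement for free:

```python
class Param:
    """Stand-in for a tensor parameter that tracks its device."""
    def __init__(self, name):
        self.name = name
        self.device = "cpu"

class Module:
    """Stand-in for torch.nn.Module's recursive .cuda() behaviour."""
    def __init__(self):
        self._params = []
        self._children = []

    def cuda(self, device=0):
        # Move this module's own parameters, then recurse into
        # submodules, mirroring how nn.Module.cuda() walks the tree.
        for p in self._params:
            p.device = f"cuda:{device}"
        for child in self._children:
            child.cuda(device)
        return self

class CRF(Module):
    """Toy CRF: just registers two named parameters."""
    def __init__(self, num_tags):
        super().__init__()
        self.num_tags = num_tags
        self._params.append(Param("start_transitions"))
        self._params.append(Param("transitions"))

crf = CRF(5)
crf.cuda(0)
print([p.device for p in crf._params])  # ['cuda:0', 'cuda:0']
```

Because the move is recursive over the whole parameter tree, a separate constructor flag would duplicate what the inherited method already does.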
Makes sense, thanks.
Adds a use_gpu flag to the CRF constructor which, if set to True, places the internal variables on the GPU. This means that viterbi_decode and forward can now be called with data already on the GPU, leading to significant speed gains (see #16).
Tests and linters pass.