Closed: JeppeHallgren closed this pull request 6 years ago.
Hi, thank you so much for this!
This looks really great, but I can't accept it without tests. Can you please add tests to check if the result from a batched Viterbi decoding is exactly the same as the non-batched one? Also, just to confirm, the 3x speed up was compared to the non-batched version and both versions were run on a GPU, right?
@kmkurn cool - yes, the output of the batched version should be exactly the same as the un-batched one. If we get a good way of manually setting the internal parameters (#18 :)) so the model is not randomly initialized, I can add a test that checks that the output of the Viterbi decode didn't change before and after this PR.
Yes, the speedup was compared to the unbatched version and both cases were run on a GPU.
Great! Anyway, I think you can test by simulating a non-batched version via a loop and running the batched version, and then check if the two results are equal. That shouldn't require setting the internal parameters.
Awesome!
> test by simulating a non-batched version via a loop and running the batched version, and then check if the two results are equal.
@kmkurn I did this on my local machine and it seems to work. I can't do this in a test though, since the old non-batched version no longer exists after the above change.
What I meant is something like:

```python
from torchcrf import CRF
import torch

batch_size, seq_len, num_tags = 2, 3, 4
crf = CRF(num_tags)
# CRF expects (seq_len, batch_size, ...) ordering
emissions = torch.randn(seq_len, batch_size, num_tags)
# shape: (seq_len, batch_size); first timestep must be all on
mask = torch.ByteTensor([[1, 1], [1, 1], [1, 0]])

# non-batched: decode each sample on its own
non_batched = []
for i in range(batch_size):
    # shape: (seq_len, 1, num_tags)
    emissions_ = emissions[:, i].unsqueeze(1)
    # shape: (seq_len, 1)
    mask_ = mask[:, i].unsqueeze(1)
    result = crf.decode(emissions_, mask=mask_)
    assert len(result) == 1
    non_batched.append(result[0])

# batched: decode the whole minibatch at once
batched = crf.decode(emissions, mask=mask)
assert non_batched == batched
```
I probably messed up the dimension order but you get the idea.
@kmkurn I fixed the dimension order and added the test.
@kmkurn let me know if you want other changes here
@kmkurn All your comments should be addressed in the latest commit.
Currently each sample in the minibatch is processed serially in `viterbi_decode`. The change below removes the for-loop and instead computes the best paths for all samples at once.
For my specific use case (minibatch_size = 50, num_tags = 85, avg sequence length = 1000), this gave a 3x speed-up when executed on a GPU.
Tests and linters pass.
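For reference, the vectorization amounts to replacing the per-sample loop with one broadcasted max over transition scores per timestep, keeping padded positions frozen via the mask. Below is a minimal standalone sketch of that idea; the function name `viterbi_decode_batched` and the bare `transitions` matrix are hypothetical, and pytorch-crf's actual method additionally handles start/end transition parameters:

```python
import torch

def viterbi_decode_batched(emissions, transitions, mask):
    """Sketch of batched Viterbi decoding (not the pytorch-crf API).

    emissions:   (seq_len, batch_size, num_tags) float tensor
    transitions: (num_tags, num_tags); transitions[i, j] = score of i -> j
    mask:        (seq_len, batch_size) bool tensor; first timestep all True
    Returns a list of best tag sequences, one per sample.
    """
    seq_len, batch_size, num_tags = emissions.shape
    score = emissions[0]  # (batch_size, num_tags)
    history = []
    for t in range(1, seq_len):
        # (batch, cur_tag, 1) + (cur_tag, next_tag) + (batch, 1, next_tag)
        # -> (batch, cur_tag, next_tag), then max over cur_tag
        next_score = score.unsqueeze(2) + transitions + emissions[t].unsqueeze(1)
        next_score, indices = next_score.max(dim=1)
        # only advance samples that are not padded at this timestep
        score = torch.where(mask[t].unsqueeze(1), next_score, score)
        history.append(indices)
    # backtrace each sample from its last valid timestep
    seq_ends = mask.long().sum(dim=0) - 1
    best_paths = []
    for b in range(batch_size):
        best_tag = score[b].argmax().item()
        path = [best_tag]
        for hist in reversed(history[: int(seq_ends[b])]):
            best_tag = hist[b][best_tag].item()
            path.append(best_tag)
        path.reverse()
        best_paths.append(path)
    return best_paths
```

The speed-up comes from the `(batch, num_tags, num_tags)` broadcasted `max`, which lets the GPU score all samples and all tag transitions in one kernel instead of looping over the minibatch in Python; the only remaining per-sample work is the cheap backtrace.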