kmkurn / pytorch-crf

(Linear-chain) Conditional random field in PyTorch.
https://pytorch-crf.readthedocs.io
MIT License

Compute Viterbi Decode in parallel on minibatch #16

Closed JeppeHallgren closed 6 years ago

JeppeHallgren commented 6 years ago

Currently, each sample in the minibatch is processed serially in viterbi_decode. The change below removes the for-loop and instead computes the best paths for all samples at once.
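To illustrate the idea, here is a minimal sketch of what a batched Viterbi pass looks like (not the actual diff; the function name and shapes are assumptions, and start/end transition scores are omitted for brevity):

import torch

def batched_viterbi(emissions, transitions, mask):
    # Hypothetical sketch, simplified (no start/end transition scores).
    # emissions: (seq_len, batch_size, num_tags)
    # transitions: (num_tags, num_tags); transitions[i, j] = score of tag i -> tag j
    # mask: (seq_len, batch_size); 1 for real tokens, 0 for padding (mask[0] all ones)
    seq_len, batch_size, num_tags = emissions.shape
    score = emissions[0]   # (batch_size, num_tags): best score ending in each tag so far
    history = []           # one (batch_size, num_tags) backpointer tensor per time step
    for t in range(1, seq_len):
        # (batch, src, 1) + (src, dst) + (batch, 1, dst) -> (batch, src, dst)
        candidates = score.unsqueeze(2) + transitions + emissions[t].unsqueeze(1)
        best, backptr = candidates.max(dim=1)      # best predecessor per destination tag
        advance = mask[t].bool().unsqueeze(1)
        score = torch.where(advance, best, score)  # only advance unpadded positions
        history.append(backptr)
    # Backtrace, shown per sample for clarity; the PR parallelizes this part as well.
    lengths = mask.long().sum(dim=0)
    best_paths = []
    for b in range(batch_size):
        last = lengths[b].item() - 1
        tag = score[b].argmax().item()
        path = [tag]
        for backptr in reversed(history[:last]):
            tag = backptr[b, tag].item()
            path.append(tag)
        best_paths.append(path[::-1])
    return best_paths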

For my specific use case (minibatch_size = 50, num_tags = 85, avg sequence length = 1000), this gave a 3x speed-up when executed on a GPU.

Tests and linters pass.

coveralls commented 6 years ago


Coverage remained the same at 100.0% when pulling 9c57104ced2658db6757da49908ceb788050b38d on JeppeHallgren:feature/batched-viterbi-decode into 839299625cd2fdae4a9b3d6aa87c230010f7961e on kmkurn:master.

kmkurn commented 6 years ago

Hi, thank you so much for this!

This looks really great, but I can't accept it without tests. Can you please add tests that check that batched Viterbi decoding produces exactly the same result as the non-batched version? Also, just to confirm: the 3x speed-up was measured against the non-batched version, and both versions were run on a GPU, right?

JeppeHallgren commented 6 years ago

@kmkurn cool - yes, the output of the batched version should be exactly the same as the non-batched one's. If we get a good way of manually setting the internal parameters ( #18 :) ), so the model is not randomly initialized, I can add a test that checks that the output of viterbi_decode didn't change before and after this PR.
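For reference, manually pinning the parameters could look something like this sketch (assuming direct access to the CRF module's start_transitions, end_transitions, and transitions parameters; the values here are arbitrary):

import torch
from torchcrf import CRF

num_tags = 4
crf = CRF(num_tags)
# Overwrite the randomly initialized parameters with fixed values
# so decoding becomes deterministic across runs.
with torch.no_grad():
    crf.start_transitions.fill_(0.)
    crf.end_transitions.fill_(0.)
    crf.transitions.copy_(torch.eye(num_tags))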

Yes, the speedup was compared to the unbatched version and both cases were run on a GPU.

kmkurn commented 6 years ago

Great! Anyway, I think you can test by simulating a non-batched version via a loop, running the batched version, and then checking that the two results are equal. That shouldn't require setting the internal parameters.

> Yes, the speedup was compared to the unbatched version and both cases were run on a GPU.

Awesome!

JeppeHallgren commented 6 years ago

> test by simulating a non-batched version via a loop, running the batched version, and then checking that the two results are equal

@kmkurn I did this on my local machine and it seems to work. I can't do this in a test, though, since the old non-batched version no longer exists after the above change?

kmkurn commented 6 years ago

What I meant is something like

from torchcrf import CRF
import torch

seq_len, batch_size, num_tags = 3, 2, 4
crf = CRF(num_tags)
# CRF expects emissions of shape (seq_len, batch_size, num_tags)
# and a mask of shape (seq_len, batch_size)
emissions = torch.randn(seq_len, batch_size, num_tags)
mask = torch.ByteTensor([[1, 1], [1, 1], [1, 0]])

# non-batched: decode each sample separately
non_batched = []
for i in range(batch_size):
    # shape: (seq_len, 1, num_tags)
    emissions_ = emissions[:, i:i + 1, :]
    # shape: (seq_len, 1)
    mask_ = mask[:, i:i + 1]

    result = crf.decode(emissions_, mask=mask_)
    assert len(result) == 1
    non_batched.append(result[0])

# batched: decode the whole minibatch at once
batched = crf.decode(emissions, mask=mask)

assert non_batched == batched

I probably messed up the dimension order but you get the idea.

JeppeHallgren commented 6 years ago

@kmkurn I fixed the dimension order and added the test.

JeppeHallgren commented 6 years ago

@kmkurn let me know if you want any other changes here.

JeppeHallgren commented 6 years ago

@kmkurn All your comments should be addressed in the latest commit.