kmkurn / pytorch-crf

(Linear-chain) Conditional random field in PyTorch.
https://pytorch-crf.readthedocs.io
MIT License

Not working for TorchScript tracing #79

Closed erksch closed 3 years ago

erksch commented 3 years ago

Hey there!

Thank you for the awesome module! I want to use the module in a mobile setting and need to trace my model to convert it to TorchScript. But sadly the implementation does not seem to allow it at the moment.

If you trace a model that uses the module internally, the trace cannot generalize beyond the dummy input.

```python
import torch

# `model` is assumed to be a tagger whose forward pass calls this module's
# decode internally.

# dummy input for tracing
dummy_seq_len = 5
dummy_input = torch.rand(1, dummy_seq_len)

traced = torch.jit.trace(model, dummy_input)

# test whether the traced model also works for other sequence lengths
x = torch.rand(1, 3)
y = traced(x)
y.shape  # (1, 5) -- the output of the decode function, but as a tensor

x = torch.rand(1, 12)
y = traced(x)
y.shape  # (1, 5)

x = torch.rand(1, 20)
y = traced(x)
y.shape  # (1, 5)
```

As you can see, the traced model only supports the dummy input's sequence length. This is because the Viterbi decode implementation uses loops and Python values instead of relying only on tensor operations. You can see it in the warnings emitted while tracing the model:

```
ner/crf.py:275: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  for i in range(1, seq_length):
ner/crf.py:310: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  for idx in range(batch_size):
ner/crf.py:314: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  best_tags = [best_last_tag.item()]
ner/crf.py:318: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  for hist in reversed(history[:seq_ends[idx]]):
ner/crf.py:320: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
```

I hope this problem makes sense to someone who may not be familiar with tracing and TorchScript. Do you think the Viterbi decode could be implemented as plain tensor operations in order to fix this?
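For context: tracing records the operations executed on the example input, so a loop whose bound comes from a tensor's shape is unrolled a fixed number of times, while scripting keeps it as a real loop. A minimal sketch with a hypothetical function:

```python
import torch

def sum_elements(x: torch.Tensor) -> torch.Tensor:
    # The loop bound depends on the input's length.
    total = torch.zeros(1)
    for i in range(x.size(0)):
        total = total + x[i]
    return total

traced = torch.jit.trace(sum_elements, torch.rand(5))  # loop unrolled 5 times
scripted = torch.jit.script(sum_elements)              # loop preserved

print(scripted(torch.rand(3)))  # works for any length
# traced(torch.rand(3)) would index past the end: the length 5 is baked in
```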

erksch commented 3 years ago

I don't even know if it is possible to implement the Viterbi algorithm without a loop, since it is a dynamic programming algorithm. Maybe one has to let go of variable sequence lengths and use a fixed size with padding tokens instead.

kmkurn commented 3 years ago

Hi, I'm not too familiar with TorchScript and have never used it, but from your explanation it seems TorchScript isn't able to trace variable input lengths? That looks like a serious limitation :/

As you said, I'm not sure either whether it's possible to get rid of the loop in Viterbi... Using a fixed size could be a workaround, but it seems unnecessarily inefficient memory-wise. Not sure if there's a better alternative.

erksch commented 3 years ago

I've managed to reimplement the score calculation with only tensor operations and no loop, and that part can be traced. But the history path generation afterwards is extremely tricky. I'll give it a try tomorrow. If that were solved as well, it should work for variable-length inputs.
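For reference, the per-timestep score update can indeed be written with broadcasting alone. A minimal sketch following pytorch-crf's shape conventions (not necessarily the exact rewrite described here):

```python
import torch

def viterbi_step(score: torch.Tensor,        # (batch, num_tags) best scores so far
                 transitions: torch.Tensor,  # (num_tags, num_tags), from -> to
                 emission: torch.Tensor):    # (batch, num_tags) at this timestep
    # Broadcast to (batch, from_tag, to_tag) and take the best predecessor.
    next_score = score.unsqueeze(2) + transitions + emission.unsqueeze(1)
    best_score, backpointer = next_score.max(dim=1)
    return best_score, backpointer  # each (batch, num_tags)
```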

erksch commented 3 years ago

If I understand the path generation correctly, it boils down to this (simplified version; each step indexes the backpointer table with the previously chosen tag):

```python
import torch

# Backpointer table: a[i, j] says which tag preceded tag j at step i.
a = torch.tensor([[0, 1, 1],
                  [2, 1, 0],
                  [1, 0, 1]], dtype=torch.long)
path = torch.zeros(4, dtype=torch.long)

start_idx = 0
path[0] = start_idx

for i in range(a.shape[0]):
    path[i + 1] = a[i, path[i]]  # follow the pointer from the previous tag

print(path)
# tensor([0, 0, 2, 1])
```

If that could be done with only tensor operations the problem would be solved.
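The step-to-step dependence makes the time loop hard to remove entirely, but the loop over the batch can be replaced with torch.gather. A rough sketch of a batched backtrace, assuming the backpointers have been stacked into a single (seq_length - 1, batch_size, num_tags) tensor and ignoring mask/padding handling:

```python
import torch

def backtrack(history: torch.Tensor, best_last_tag: torch.Tensor) -> torch.Tensor:
    # history: (seq_length - 1, batch_size, num_tags) backpointers
    # best_last_tag: (batch_size,) best tag at the final timestep
    tags = [best_last_tag]
    for hist in history.flip(0):  # walk backwards through time
        # For each batch element, look up where its current tag came from.
        prev = hist.gather(1, tags[-1].unsqueeze(1)).squeeze(1)
        tags.append(prev)
    tags.reverse()
    return torch.stack(tags, dim=0)  # (seq_length, batch_size)
```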

erksch commented 3 years ago

Actually I'm a dumb dumb: aside from tracing, there is a second way to create TorchScript. You can annotate functions with @torch.jit.script or wrap a function with it. Nevertheless, I made some optimizations to the code and could open a PR, if you want, that reduces the number of loops and lists to a minimum and uses mostly plain tensor operations. That works for tracing out of the box and still allows variable-length inputs.
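The two forms look like this (trivial hypothetical functions):

```python
import torch

@torch.jit.script  # annotate a function directly with the decorator...
def relu_sum(x: torch.Tensor) -> torch.Tensor:
    return x.clamp(min=0).sum()

def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

scripted_double = torch.jit.script(double)  # ...or wrap an existing function
# Modules work the same way: torch.jit.script(my_module)
```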

kmkurn commented 3 years ago

Good to hear you've found a solution! A PR is definitely welcome! (Some benchmark numbers will also be much appreciated.) But it'll probably be a while before I can look into it, as I have limited availability at the moment.

Also, I'm closing the issue as you've worked out a solution.

etern commented 2 years ago

Both scripting and tracing failed for me. @erksch, can you share your solution in more detail?

etern commented 2 years ago

Finally, I tried it out by evaluating torch.jit.script(CRF()) and letting it raise errors.

The changes are shown in the following diff. (PS: I don't know which version of the file I have, so just ignore the line numbers. Also, I'm using PyTorch 1.9.)

```diff
--- a/crf.py
+++ b/crf.py
@@ -134,8 +134,9 @@ class CRF(nn.Module):
         assert reduction == 'token_mean'
         return llh.sum() / mask.float().sum()
 
+    @torch.jit.export
     def decode(self, emissions: torch.Tensor,
-               mask: Optional[torch.ByteTensor] = None) -> List[List[int]]:
+               mask: Optional[torch.ByteTensor] = None) -> torch.Tensor:
         """Find the most likely tag sequence using Viterbi algorithm.
 
         Args:
@@ -150,7 +151,7 @@ class CRF(nn.Module):
         """
         self._validate(emissions, mask=mask)
         if mask is None:
-            mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8)
+            mask = torch.ones(emissions.shape[:2], dtype=torch.uint8, device=emissions.device)
 
         if self.batch_first:
             emissions = emissions.transpose(0, 1)
@@ -174,13 +175,13 @@ class CRF(nn.Module):
             if emissions.shape[:2] != tags.shape:
                 raise ValueError(
                     'the first two dimensions of emissions and tags must match, '
-                    f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')
+                    f'got {emissions.shape[:2]} and {tags.shape}')
 
         if mask is not None:
             if emissions.shape[:2] != mask.shape:
                 raise ValueError(
                     'the first two dimensions of emissions and mask must match, '
-                    f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
+                    f'got {emissions.shape[:2]} and {mask.shape}')
             no_empty_seq = not self.batch_first and mask[0].all()
             no_empty_seq_bf = self.batch_first and mask[:, 0].all()
             if not no_empty_seq and not no_empty_seq_bf:
@@ -277,7 +278,7 @@ class CRF(nn.Module):
         return torch.logsumexp(score, dim=1)
 
     def _viterbi_decode(self, emissions: torch.FloatTensor,
-                        mask: torch.ByteTensor) -> List[List[int]]:
+                        mask: torch.ByteTensor) -> torch.Tensor:
         # emissions: (seq_length, batch_size, num_tags)
         # mask: (seq_length, batch_size)
         assert emissions.dim() == 3 and mask.dim() == 2
@@ -290,7 +291,7 @@ class CRF(nn.Module):
         # Start transition and first emission
         # shape: (batch_size, num_tags)
         score = self.start_transitions + emissions[0]
-        history = []
+        history: List[torch.Tensor] = []
 
         # score is a tensor of size (batch_size, num_tags) where for every batch,
         # value at column j stores the score of the best tag sequence so far that ends
@@ -322,7 +323,7 @@ class CRF(nn.Module):
             # Set score to the next score if this timestep is valid (mask == 1)
             # and save the index that produces the next score
             # shape: (batch_size, num_tags)
-            score = torch.where(mask[i].unsqueeze(1).bool(), next_score, score)
+            score = torch.where(mask[i].unsqueeze(1).to(torch.bool), next_score, score)
             history.append(indices)
 
         # End transition score
@@ -333,22 +334,22 @@ class CRF(nn.Module):
         # shape: (batch_size,)
         seq_ends = mask.long().sum(dim=0) - 1
-        best_tags_list = []
+        best_tags_list: List[List[int]] = []
 
         for idx in range(batch_size):
             # Find the tag which maximizes the score at the last timestep; this is our best tag
             # for the last timestep
             _, best_last_tag = score[idx].max(dim=0)
-            best_tags = [best_last_tag.item()]
+            best_tags: List[int] = [int(best_last_tag.item())]
 
             # We trace back where the best last tag comes from, append that to our best tag
             # sequence, and trace it back again, and so on
-            for hist in reversed(history[:seq_ends[idx]]):
+            for hist in history[:seq_ends[idx]][::-1]:
                 best_last_tag = hist[idx][best_tags[-1]]
-                best_tags.append(best_last_tag.item())
+                best_tags.append(int(best_last_tag.item()))
 
             # Reverse the order because we start from the last timestep
             best_tags.reverse()
             best_tags_list.append(best_tags)
 
-        return best_tags_list
+        return torch.tensor(best_tags_list)
```
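With this patch, usage could look roughly like the following (a sketch assuming the patched crf.py above and PyTorch 1.9; note that the final torch.tensor(best_tags_list) only succeeds when all sequences in the batch decode to the same length, e.g. when mask is None):

```python
import torch
from crf import CRF  # the patched module from the diff above

crf = CRF(num_tags=5, batch_first=True)
scripted = torch.jit.script(crf)  # scripting now succeeds

emissions = torch.randn(2, 7, 5)  # (batch, seq_len, num_tags)
best_tags = scripted.decode(emissions)  # reachable thanks to @torch.jit.export
print(best_tags.shape)  # torch.Size([2, 7]); a tensor now, not List[List[int]]
```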
MrRace commented 1 year ago

What is the solution? @erksch