kmkurn / pytorch-crf

(Linear-chain) Conditional random field in PyTorch.
https://pytorch-crf.readthedocs.io
MIT License

Make compatible with torchscript #112

Closed · erksch closed this 1 year ago

erksch commented 1 year ago

Hey there!

Many people want this repository to be compatible with TorchScript (#100, #79) so that it can be used for mobile inference. We found a simple, non-breaking solution and hope that it can be merged.

We added a test that scripts the CRF module. Because the scripted module should also expose the decode method, we added the @torch.jit.export annotation to it.
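Roughly, the annotation and the new test look like this (a sketch, not the exact code in the diff; the test name and tensor shapes here are made up):

```python
import torch
from torchcrf import CRF

# Inside torchcrf/__init__.py, decode() carries the export annotation,
# because torch.jit.script only compiles forward() and exported methods:
#
#     @torch.jit.export
#     def decode(self, emissions, mask=None): ...

def test_crf_is_scriptable():
    crf = CRF(num_tags=5)
    scripted = torch.jit.script(crf)
    emissions = torch.randn(3, 2, 5)  # (seq_length, batch_size, num_tags)
    assert scripted.decode(emissions) == crf.decode(emissions)
```

Running the scripting test in CI should also keep future changes from silently breaking mobile export.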

We had to make two other changes to get the module to script. These are the errors we hit along the way:

```
E       RuntimeError: 
E       cannot statically infer the expected size of a list in this context:
E         File "/home/erik/Projects/voize/pytorch-crf/torchcrf/__init__.py", line 156
E                   if emissions.shape[:2] != tags.shape:
E                       raise ValueError(
E                           'the first two dimensions of emissions and tags must match, '
E                            ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
E                           f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')
E           
E       'CRF._validate' is being compiled since it was called from 'CRF.forward'
E         File "/home/erik/Projects/voize/pytorch-crf/torchcrf/__init__.py", line 90
E                   reduction is ``none``, ``()`` otherwise.
E               """
E               self._validate(emissions, tags=tags, mask=mask)
E               ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
E               if reduction not in ('none', 'sum', 'mean', 'token_mean'):
E                   raise ValueError(f'invalid reduction: {reduction}')
```
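This first failure comes from the `tuple(...)` calls in the error message: TorchScript requires tuple lengths to be statically known, and `emissions.shape[:2]` is a dynamically sized `List[int]`. A minimal sketch of the kind of change that gets past it (assuming TorchScript can format the shape lists directly; if not, the message can be built without them):

```python
# The comparison itself scripts fine (List[int] vs. List[int]);
# only the tuple(...) calls in the f-string had to go.
if emissions.shape[:2] != tags.shape:
    raise ValueError(
        'the first two dimensions of emissions and tags must match, '
        f'got {emissions.shape[:2]} and {list(tags.shape)}')
```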
```
E       RuntimeError: 
E       
E       reversed(Tensor 0) -> Tensor 0:
E       Expected a value of type 'Tensor' for argument '0' but instead found type 'List[Tensor]'.
E       Empty lists default to List[Tensor]. Add a variable annotation to the assignment to create an empty list of another type (torch.jit.annotate(List[T, []]) where T is the type of elements in the list for Python 2)
E       :
E         File "/home/erik/Projects/voize/pytorch-crf/torchcrf/__init__.py", line 335
E                   # We trace back where the best last tag comes from, append that to our best tag
E                   # sequence, and trace it back again, and so on
E                   for hist in reversed(history[:seq_ends[idx]]):
E                               ~~~~~~~~ <--- HERE
E                       best_last_tag = hist[idx][best_tags[-1]]
E                       best_tags.append(best_last_tag.item())
E       'CRF._viterbi_decode' is being compiled since it was called from 'CRF.decode'
E         File "/home/erik/Projects/voize/pytorch-crf/torchcrf/__init__.py", line 140
E                   mask = mask.transpose(0, 1)
E           
E               return self._viterbi_decode(emissions, mask)
E                      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
```
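The second failure is that `reversed` only has a tensor overload here and cannot be scripted for a `List[Tensor]`. Swapping it for slice-based reversal gets past this (and then surfaces the third error below):

```python
# A [::-1] slice reverses a List[Tensor] where reversed() cannot be scripted
for hist in history[:seq_ends[idx]][::-1]:
    best_last_tag = hist[idx][best_tags[-1]]
    best_tags.append(best_last_tag.item())
```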
```
E       RuntimeError: 
E       Unsupported operation: indexing tensor with unsupported index type 'number'. Only ints, slices, lists and tensors are supported:
E         File "/home/erik/Projects/voize/pytorch-crf/torchcrf/__init__.py", line 337
E                   # sequence, and trace it back again, and so on
E                   for hist in history[:seq_ends[idx]][::-1]:
E                       best_last_tag = hist[idx][best_tags[-1]]
E                                       ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
E                       best_tags.append(best_last_tag.item())
E           
E       'CRF._viterbi_decode' is being compiled since it was called from 'CRF.decode'
E         File "/home/erik/Projects/voize/pytorch-crf/torchcrf/__init__.py", line 141
E                   mask = mask.transpose(0, 1)
E           
E               return self._viterbi_decode(emissions, mask)
E                      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
```
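This last failure happens because `.item()` is typed as `number` in TorchScript (it could be an `int` or a `float`), and a tensor cannot be indexed with a `number`. Casting to `int` so that `best_tags` stays a `List[int]` resolves it; a sketch, assuming `best_tags` starts out as a list of Python ints:

```python
# int(...) pins the appended value to int, so best_tags is List[int]
# and best_tags[-1] is a valid tensor index under TorchScript
for hist in history[:seq_ends[idx]][::-1]:
    best_last_tag = hist[idx][best_tags[-1]]
    best_tags.append(int(best_last_tag.item()))
```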
erksch commented 1 year ago

Lol, there already exists a PR doing the exact same thing! Never mind then, and I hope it gets merged in the future :)