Many people want this repository to be compatible with TorchScript (#100, #79) so that it can be used for mobile inference. We found a simple, non-breaking solution and hope that it can be merged.
We added a test that scripts the CRF module. Because the scripted module should also expose the `decode` method, we added the `@torch.jit.export` annotation to it.
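To illustrate why the annotation is needed, here is a minimal sketch with a hypothetical `TinyTagger` module (a stand-in for the CRF, not the actual implementation): `torch.jit.script` only compiles `forward` and whatever it calls, so a method like `decode` has to be exported explicitly.

```python
import torch
import torch.nn as nn
from typing import List


class TinyTagger(nn.Module):
    """Hypothetical stand-in for the CRF module."""

    def forward(self, emissions: torch.Tensor) -> torch.Tensor:
        # forward() is compiled by default when the module is scripted.
        return emissions.sum()

    @torch.jit.export
    def decode(self, emissions: torch.Tensor) -> List[int]:
        # Without @torch.jit.export this method would be dropped from the
        # scripted module, because forward() never calls it.
        best: List[int] = []
        for i in range(emissions.size(0)):
            best.append(int(torch.argmax(emissions[i]).item()))
        return best


scripted = torch.jit.script(TinyTagger())
tags = scripted.decode(torch.tensor([[0.1, 0.9], [0.8, 0.2]]))  # [1, 0]
```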
We had to make two other changes to make it work:
`emissions.shape[:2] != tags.shape` is not scriptable and throws the error `cannot statically infer the expected size of a list in this context`. We changed it so that the two dimensions are checked individually:
```
E RuntimeError:
E cannot statically infer the expected size of a list in this context:
E   File "/home/erik/Projects/voize/pytorch-crf/torchcrf/__init__.py", line 156
E         if emissions.shape[:2] != tags.shape:
E             raise ValueError(
E                 'the first two dimensions of emissions and tags must match, '
E                 ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
E                 f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')
E
E 'CRF._validate' is being compiled since it was called from 'CRF.forward'
E   File "/home/erik/Projects/voize/pytorch-crf/torchcrf/__init__.py", line 90
E         reduction is ``none``, ``()`` otherwise.
E         """
E         self._validate(emissions, tags=tags, mask=mask)
E         ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
E         if reduction not in ('none', 'sum', 'mean', 'token_mean'):
E             raise ValueError(f'invalid reduction: {reduction}')
```
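A sketch of the per-dimension check (the function name `validate_shapes` is ours for illustration; the error message follows the one in the traceback above): comparing `size(0)` and `size(1)` individually avoids building the list that TorchScript cannot size statically.

```python
import torch


def validate_shapes(emissions: torch.Tensor, tags: torch.Tensor) -> None:
    # emissions.shape[:2] builds a list whose size TorchScript cannot infer
    # statically, so compare the first two dimensions one by one instead.
    if emissions.size(0) != tags.size(0) or emissions.size(1) != tags.size(1):
        raise ValueError(
            'the first two dimensions of emissions and tags must match, '
            f'got ({emissions.size(0)}, {emissions.size(1)}) '
            f'and ({tags.size(0)}, {tags.size(1)})')


# (seq_length, batch_size, num_tags) emissions with matching tags: passes.
validate_shapes(torch.randn(3, 2, 5), torch.zeros(3, 2, dtype=torch.long))
```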
The same applies to `emissions.shape[:2] != mask.shape`.
`reversed(history[:seq_ends[idx]])` in `_viterbi_decode` did not seem to be supported, so we changed it to `history[:seq_ends[idx]][::-1]`:
```
E RuntimeError:
E
E reversed(Tensor 0) -> Tensor 0:
E Expected a value of type 'Tensor' for argument '0' but instead found type 'List[Tensor]'.
E Empty lists default to List[Tensor]. Add a variable annotation to the assignment to create an empty list of another type (torch.jit.annotate(List[T, []]) where T is the type of elements in the list for Python 2)
E :
E   File "/home/erik/Projects/voize/pytorch-crf/torchcrf/__init__.py", line 335
E         # We trace back where the best last tag comes from, append that to our best tag
E         # sequence, and trace it back again, and so on
E         for hist in reversed(history[:seq_ends[idx]]):
E                     ~~~~~~~~ <--- HERE
E             best_last_tag = hist[idx][best_tags[-1]]
E             best_tags.append(best_last_tag.item())
E 'CRF._viterbi_decode' is being compiled since it was called from 'CRF.decode'
E   File "/home/erik/Projects/voize/pytorch-crf/torchcrf/__init__.py", line 140
E             mask = mask.transpose(0, 1)
E
E         return self._viterbi_decode(emissions, mask)
E                ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
```
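A small sketch of the equivalent iteration orders (the helper name `traceback_order` and the toy `history` list are ours for illustration): a negative-step slice produces the same reverse traversal that `reversed()` would, but is expressible in TorchScript.

```python
import torch
from typing import List


def traceback_order(history: List[torch.Tensor], seq_end: int) -> List[torch.Tensor]:
    # TorchScript rejected reversed() over this List[Tensor] slice;
    # a [::-1] slice yields the same elements in reverse order.
    return history[:seq_end][::-1]


history = [torch.tensor([0]), torch.tensor([1]), torch.tensor([2])]
out = traceback_order(history, 2)  # tensors holding 1, then 0
```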
`best_last_tag.item()` needs to be cast to an `int`, otherwise the inferred type cannot be used to index `hist[idx]`:
```
E RuntimeError:
E Unsupported operation: indexing tensor with unsupported index type 'number'. Only ints, slices, lists and tensors are supported:
E   File "/home/erik/Projects/voize/pytorch-crf/torchcrf/__init__.py", line 337
E         # sequence, and trace it back again, and so on
E         for hist in history[:seq_ends[idx]][::-1]:
E             best_last_tag = hist[idx][best_tags[-1]]
E                             ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
E             best_tags.append(best_last_tag.item())
E
E 'CRF._viterbi_decode' is being compiled since it was called from 'CRF.decode'
E   File "/home/erik/Projects/voize/pytorch-crf/torchcrf/__init__.py", line 141
E             mask = mask.transpose(0, 1)
E
E         return self._viterbi_decode(emissions, mask)
E                ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
```
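A minimal sketch of the cast (the variable names here mirror the traceback but the data is a toy example): `.item()` returns a Python scalar that TorchScript types as a generic number, which is not a valid tensor index, while an explicit `int(...)` makes the type concrete.

```python
import torch

hist_row = torch.tensor([10, 20, 30])
best_last_tag = torch.argmax(hist_row)  # index of the max value, i.e. 2

# .item() alone is inferred as a generic 'number' under TorchScript, which
# cannot index a tensor; the explicit int(...) cast fixes the index type.
best_tags = [int(best_last_tag.item())]
next_tag = hist_row[best_tags[-1]]  # indexing with a plain int works
```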