Closed erksch closed 3 years ago
I don't even know if it is possible to implement the Viterbi algorithm without a loop, since it is a dynamic programming algorithm. Maybe one has to let go of variable sequence lengths and use a fixed size with padding tokens instead.
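A padding-based workaround might look like this rough sketch (MAX_LEN and PAD_ID are hypothetical choices, not part of the library):

```python
import torch

# Hypothetical fixed length and padding token id
MAX_LEN = 6
PAD_ID = 0

def pad_to_fixed(seq: torch.Tensor):
    """Pad a 1-D sequence to MAX_LEN and return a boolean mask."""
    padded = torch.full((MAX_LEN,), PAD_ID, dtype=seq.dtype)
    padded[: seq.size(0)] = seq
    mask = torch.zeros(MAX_LEN, dtype=torch.bool)
    mask[: seq.size(0)] = True
    return padded, mask

padded, mask = pad_to_fixed(torch.tensor([3, 1, 2]))
print(padded)  # tensor([3, 1, 2, 0, 0, 0])
```

Every input then has the same shape, so tracing sees a single sequence length, at the cost of wasted memory for short sequences.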
Hi, I'm not too familiar with TorchScript and have never used it, but from your explanation it seems tracing can't handle variable input lengths? Looks like a serious limitation :/
As you said, I'm not sure either if it's possible to get rid of the loop in Viterbi... Using fixed size could be a workaround, but it seems unnecessarily inefficient memory-wise. Not sure if there's a better alternative.
I've managed to reimplement the score calculation without a loop, using only tensor operations, and that part can be traced. But the history path generation afterwards is extremely tricky. I'll give it a try tomorrow. If that is solved as well, it should work for variable inputs.
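For illustration, a single loop-free Viterbi score update over the tag dimension could look like this sketch (my own illustration, not necessarily the library's actual code):

```python
import torch

# score:       (batch, num_tags)      best score ending in each tag so far
# transitions: (num_tags, num_tags)   transition scores prev_tag -> next_tag
# emission:    (batch, num_tags)      emission scores for the current timestep
def viterbi_step(score, transitions, emission):
    # Broadcast: (batch, prev, 1) + (prev, next) + (batch, 1, next)
    next_score = score.unsqueeze(2) + transitions + emission.unsqueeze(1)
    # Max over the previous tag; the indices are the backpointers
    best_score, best_prev = next_score.max(dim=1)
    return best_score, best_prev
```

The update for one timestep is pure tensor ops, though the steps themselves still have to run in sequence, which is where the backtracking problem below comes in.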
If I understand the path generation correctly it boils down to this (simplified version):
a = torch.tensor([[0, 1, 1],
                  [2, 1, 0],
                  [1, 0, 1]], dtype=torch.long)
path = torch.zeros(4, dtype=torch.long)
start_idx = 0
path[0] = start_idx
for i in range(a.shape[0]):
    path[i + 1] = a[i, path[i]]
print(path)
# tensor([0, 0, 2, 1])
If that could be done with only tensor operations the problem would be solved.
Actually, I'm a dumb dumb: aside from tracing, there is a second way to create TorchScript. You can annotate functions with @torch.jit.script
or wrap a function with it. Regardless, I made some optimizations to the code and could open a PR if you want; it reduces the number of loops and lists to a minimum and uses mostly plain tensor operations. That works with tracing out of the box and still allows variable inputs.
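As a sketch, the path-following loop from the earlier snippet can be wrapped with @torch.jit.script (the function name follow_path is mine). Scripting compiles the Python loop itself rather than unrolling it, so variable sequence lengths keep working:

```python
import torch

@torch.jit.script
def follow_path(a: torch.Tensor, start_idx: int) -> torch.Tensor:
    # The loop is preserved by scripting, so a's length may vary at runtime
    path = torch.zeros([a.size(0) + 1], dtype=torch.long)
    path[0] = start_idx
    for i in range(a.size(0)):
        prev = int(path[i])
        path[i + 1] = a[i, prev]
    return path

a = torch.tensor([[0, 1, 1],
                  [2, 1, 0],
                  [1, 0, 1]], dtype=torch.long)
print(follow_path(a, 0))  # tensor([0, 0, 2, 1])
```

The same scripted function also accepts histories of any other length, which is exactly what tracing could not do.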
Good to hear you've found a solution! A PR is definitely welcome (some benchmark numbers would also be much appreciated), but it'll probably be a while before I can look into it, as I have limited availability at the moment.
Also, I'm closing the issue as you've worked out a solution.
Both scripting and tracing failed for me. @erksch, can you share your solution in more detail?
Finally tried it out by evaluating torch.jit.script(CRF())
and letting it raise errors.
The changes are shown in the following diff. (PS: I don't know my file version, so just ignore the line numbers. Also, I'm using PyTorch 1.9.)
What is the solution? @erksch
Hey there!
Thank you for the awesome module! I want to use the module in a mobile setting and need to trace my model to use it for TorchScript. But sadly the implementation does not seem to allow it at the moment.
If you use a model that uses the module internally and try to trace it, the trace cannot generalize from the dummy input.
As you can see, the traced model only supports the dummy input's sequence length. This is because the implementation of Viterbi decode uses loops and Python values instead of relying only on tensor operations. You can see it in the warnings emitted when tracing the model:
I hope this problem is understandable even for someone who may not be familiar with tracing and TorchScript. Do you think there is a way to implement the Viterbi decode with plain tensor operations in order to fix this problem?
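To illustrate the limitation, here is a minimal sketch (not the actual CRF code) of a module whose forward loops over the sequence length as a Python int, the way Viterbi decode does:

```python
import torch

class LoopDecoder(torch.nn.Module):
    def forward(self, emissions: torch.Tensor) -> torch.Tensor:
        score = emissions[0]
        # The sequence length becomes a Python int here, so tracing
        # unrolls this loop for the dummy input's length only
        for i in range(1, emissions.size(0)):
            score = torch.maximum(score, emissions[i])
        return score

model = LoopDecoder()
traced = torch.jit.trace(model, torch.randn(3, 5))  # length 3 is baked in

longer = torch.arange(35.0).reshape(7, 5)
# The traced graph still only looks at the first 3 timesteps:
print(torch.equal(traced(longer), model(longer)))  # False
```

The traced version silently computes a wrong result for longer inputs instead of raising an error, which makes the problem easy to miss.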