bdusell / semiring-einsum

Generic PyTorch implementation of einsum that supports different semirings
https://bdusell.github.io/semiring-einsum/
MIT License

log_viterbi_einsum_forward raises exception with no summed-out variables #16

Status: Closed (closed by davidweichiang 2 years ago)

davidweichiang commented 2 years ago
import torch
from torch_semiring_einsum import compile_equation, log_viterbi_einsum_forward

eq = compile_equation('a,a->a')
x = torch.arange(5, dtype=torch.float64)
y = torch.arange(5, dtype=torch.float64)

log_viterbi_einsum_forward(eq, x, y, block_size=1)

Expected output: (tensor([0., 2., 4., 6., 8.], dtype=torch.float64), tensor([], size=(5, 0)))
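The expected values can be sanity-checked with plain PyTorch: in the log semiring, multiplication is ordinary addition, and since the equation 'a,a->a' sums out no variables there is nothing to maximize over, so the result should just be x + y with an empty argmax tensor. A minimal check:

```python
import torch

x = torch.arange(5, dtype=torch.float64)
y = torch.arange(5, dtype=torch.float64)

# Log-semiring "product" is elementwise addition; with no summed-out
# index, the Viterbi max is over an empty set of indices.
expected = x + y
print(expected)  # tensor([0., 2., 4., 6., 8.], dtype=torch.float64)
```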

Actual output:

  File ".../torch_semiring_einsum/log_viterbi_forward.py", line 102, in max_argmax_block
    argmax = torch.stack(argmaxes, dim=-1)
RuntimeError: stack expects a non-empty TensorList
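The traceback shows torch.stack being called on an empty list: with no summed-out variables there are no argmax tensors to stack, and torch.stack rejects an empty TensorList. A standalone sketch of the failure and one possible guard (the helper name and fallback shape here are hypothetical, not the library's actual fix):

```python
import torch

# torch.stack raises on an empty list of tensors, which is exactly
# what happens when the equation sums out no variables.
try:
    torch.stack([], dim=-1)
except RuntimeError as e:
    print(e)  # stack expects a non-empty TensorList

# Hypothetical guard: if there are no argmaxes, return an empty
# integer tensor of shape output_shape + (0,), matching the expected
# output shown in this issue.
def stack_argmaxes(argmaxes, output_shape):
    if argmaxes:
        return torch.stack(argmaxes, dim=-1)
    return torch.empty(output_shape + (0,), dtype=torch.long)

result = stack_argmaxes([], (5,))
print(result.shape)  # torch.Size([5, 0])
```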