bdusell / semiring-einsum

Generic PyTorch implementation of einsum that supports different semirings
https://bdusell.github.io/semiring-einsum/
MIT License

Allow zero-dimensional arguments #8

Closed by davidweichiang 2 years ago

davidweichiang commented 2 years ago

Closes #7.

This turned out to be a bug in PyTorch: indexing with a tuple and indexing with the equivalent list behave differently.

>>> torch.tensor(1.)[()]
tensor(1.)
>>> torch.tensor(1.)[[]]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
IndexError: too many indices for tensor of dimension 0
>>> torch.tensor([1.])[(0,)]
tensor(1.)
>>> torch.tensor([1.])[[0]]
tensor([1.])
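
The transcript appears to follow NumPy-style indexing rules. A tuple supplies one index per dimension. A plain list of integers is instead treated as a single integer-array ("advanced") index into the first dimension, which is why `[[0]]` keeps a length-1 dimension and why `[[]]` fails on a 0-dimensional tensor. The short check below uses only standard `torch` calls to make that difference explicit. The variable names are illustrative.

```python
import torch

vec = torch.tensor([1.])

# A tuple supplies one index per dimension, so (0,) picks out the single element.
by_tuple = vec[(0,)]

# A list of ints acts as an integer-array index into dimension 0,
# so [0] selects a 1-element slice and keeps that dimension.
by_list = vec[[0]]

# Equivalent explicit integer-array index, for comparison with the list form.
by_array = vec[torch.tensor([0])]

print(by_tuple.shape)                  # torch.Size([])
print(by_list.shape)                   # torch.Size([1])
print(torch.equal(by_list, by_array))  # True
```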

This PR just works around that bug.
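
The diff is not shown here, so the following is only a sketch under an assumption: that the code builds a list holding one index per dimension and that the workaround is to convert that list to a tuple before indexing. With a tuple, an empty index list becomes `()`, which PyTorch accepts for a 0-dimensional tensor. The helper name `index_by_dims` is hypothetical and is not part of the semiring-einsum API.

```python
import torch

def index_by_dims(tensor, indices):
    """Hypothetical helper: index `tensor` with one entry per dimension.

    Converting the list to a tuple makes PyTorch apply each entry to the
    corresponding dimension. In the zero-dimensional case the list is
    empty and becomes (), instead of being read as an empty integer-array
    index (which raises IndexError).
    """
    if len(indices) != tensor.dim():
        raise ValueError(f"expected {tensor.dim()} indices, got {len(indices)}")
    return tensor[tuple(indices)]

scalar = torch.tensor(1.)
vector = torch.tensor([1.])
matrix = torch.arange(6.).reshape(2, 3)  # [[0., 1., 2.], [3., 4., 5.]]

print(index_by_dims(scalar, []))   # tensor(1.)
print(index_by_dims(vector, [0]))  # tensor(1.)

# Index tensors, one per dimension, gather matrix[0, 2] and matrix[1, 0].
rows = torch.tensor([0, 1])
cols = torch.tensor([2, 0])
print(index_by_dims(matrix, [rows, cols]))  # tensor([2., 3.])
```

The length check is an extra safeguard in this sketch; it turns a silent partial index into an explicit error.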