bdusell / semiring-einsum

Generic PyTorch implementation of einsum that supports different semirings
https://bdusell.github.io/semiring-einsum/
MIT License

Use torch.amax if available #20

Closed davidweichiang closed 2 years ago

davidweichiang commented 2 years ago

torch.amax can take a max over multiple dims in a single call, which ought to be faster than calling torch.max on one dim at a time.
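For illustration, here is a minimal sketch of the idea, not the library's actual code. The helper names `max_over_dims` and `max_over_dims_loop` are hypothetical, and the sketch assumes `dims` contains distinct non-negative dimension indices.

```python
import torch

def max_over_dims_loop(x, dims):
    """Max over several dims by calling torch.max once per dim."""
    # Reduce the highest-numbered dim first so the remaining indices stay valid.
    for d in sorted(dims, reverse=True):
        x = torch.max(x, dim=d)[0]  # keep the values, drop the indices
    return x

def max_over_dims(x, dims):
    """Max over several dims; uses torch.amax when this PyTorch build has it."""
    if hasattr(torch, 'amax'):
        return torch.amax(x, dim=tuple(dims))
    return max_over_dims_loop(x, dims)

x = torch.arange(24.).reshape(2, 3, 4)
print(max_over_dims(x, (1, 2)))
```

Both helpers should return the same values; whether the single torch.amax call is actually faster would need to be measured on the tensors in question.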

I don't think there's a similar solution for log_viterbi_einsum_forward, since amax doesn't return an argmax.
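The difference in return values can be seen in a small sketch (assuming a PyTorch version that provides torch.amax; the tensor here is made up for illustration):

```python
import torch

x = torch.tensor([[1., 5., 2.],
                  [7., 0., 3.]])

# torch.max along a single dim returns both the max values and their indices.
values, indices = torch.max(x, dim=1)

# torch.amax returns only the max values as a plain tensor, with no indices.
only_values = torch.amax(x, dim=1)

print(values, indices, only_values)
```

A Viterbi-style forward pass needs the indices to recover the best path, which is why torch.amax alone is not a drop-in replacement there.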