bdusell / semiring-einsum

Generic PyTorch implementation of einsum that supports different semirings
https://bdusell.github.io/semiring-einsum/
MIT License

(feat): Support For Complex Numbers #32

Closed · ilan-gold closed this 1 year ago

ilan-gold commented 1 year ago

Hello! This seems like it would be a huge boost for me, but I need support for complex numbers - any chance I could make a PR? Right now I'm getting:

a = tensor([[[9.9567e-01, 9.9567e-01, 9.9567e-01,  ..., 9.9567e-01,
          9.9567e-01, 9.9567e-01],
         [9.9567e-0...       [1.6774e-05, 1.6774e-05, 1.6774e-05,  ..., 1.6774e-05,
          1.6774e-05, 1.6774e-05]]], dtype=torch.float64)
b = tensor([[[ 0.0000+0.0000j, -0.0024-0.0698j, -0.0097-0.1392j,  ...,
          -0.0219+0.2079j, -0.0097+0.1392j, -0.0024...698j, -0.0097-0.1392j,  ...,
          -0.0219+0.2079j, -0.0097+0.1392j, -0.0024+0.0698j]]],
       requires_grad=True)

    def multiply_in_place(a, b):
>       a.mul_(b)
E       RuntimeError: result type ComplexDouble can't be cast to the desired output type Double
ilan-gold commented 1 year ago

Or rather, my use case is one real-valued argument and one complex-valued argument.

ilan-gold commented 1 year ago

It seems like one option is just flipping the order of operations here, unless I am mistaken.
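(For illustration, a minimal sketch of the underlying behavior with plain PyTorch tensors, not the library's internals; the names a and b mirror the traceback above. In-place multiplication writes the result into the left-hand tensor, so the argument order determines which dtype the result must fit into.)

    import torch

    a = torch.rand(3, 4, dtype=torch.float64)   # real-valued operand
    b = torch.rand(3, 4, dtype=torch.cdouble)   # complex-valued operand

    # In-place multiply stores the result in the left-hand tensor, so the
    # promoted result dtype must be castable to that tensor's dtype.
    # a.mul_(b)  # RuntimeError: result type ComplexDouble can't be cast to Double
    b.mul_(a)    # works: the complex tensor can hold the complex product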

bdusell commented 1 year ago

Ah, I see what's happening. The a inside multiply_in_place gets its dtype from the first input tensor a, so it's trying to write complex numbers into a float tensor in-place. It should work if you set the dtype of that first tensor to torch.cdouble, with the imaginary part set to 0.
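(A rough sketch of that workaround with plain PyTorch tensors: casting the real tensor to torch.cdouble gives a complex copy whose imaginary part is 0, after which the in-place multiply succeeds. The idea would be to apply the same cast to the real-valued argument before passing it to einsum.)

    import torch

    a = torch.rand(3, 4, dtype=torch.float64)   # the real-valued input
    b = torch.rand(3, 4, dtype=torch.cdouble)   # the complex-valued input

    a = a.to(torch.cdouble)   # complex copy of a, imaginary part set to 0
    a.mul_(b)                 # the in-place multiply now succeeds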

bdusell commented 1 year ago

Did that help?

ilan-gold commented 1 year ago

I did not end up trying that, but flipping my operations did work, so I just went with that. Thank you for the reply though! Much appreciated!