williamFalcon / pytorch-complex-tensor

Unofficial complex tensor and scalar support for PyTorch
MIT License

PyTorch Complex Tensor

Unofficial complex Tensor support for PyTorch

How it works

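Treats the first half of the tensor (along the first dimension) as the real part and the second half as the imaginary part, so a 4x3 input represents a 2x3 complex tensor. A few arithmetic operations are implemented to emulate complex arithmetic on this layout. Supports gradients.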
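
To make the stacked layout concrete, here is a minimal sketch in plain Python (no PyTorch needed) of how element-wise complex multiplication can be emulated with real arithmetic on the two halves. The helper names `split_halves`, `to_complex` and `multiply_stacked` are hypothetical and are not part of this package; the library's own implementation may differ.

```python
# Illustrative sketch only; these helpers are NOT part of pytorch_complex_tensor.

def split_halves(stacked):
    """Split a stacked 2n-row matrix into (real rows, imaginary rows)."""
    if len(stacked) % 2 != 0:
        raise ValueError("expected an even number of rows")
    n = len(stacked) // 2
    return stacked[:n], stacked[n:]

def to_complex(stacked):
    """View the stacked matrix as a matrix of Python complex numbers."""
    real, imag = split_halves(stacked)
    return [[complex(r, i) for r, i in zip(r_row, i_row)]
            for r_row, i_row in zip(real, imag)]

def multiply_stacked(a, b):
    """Element-wise (a_re + i*a_im) * (b_re + i*b_im) using real arithmetic only."""
    a_re, a_im = split_halves(a)
    b_re, b_im = split_halves(b)
    rows = list(zip(a_re, a_im, b_re, b_im))
    real = [[ar * br - ai * bi for ar, ai, br, bi in zip(*r)] for r in rows]
    imag = [[ar * bi + ai * br for ar, ai, br, bi in zip(*r)] for r in rows]
    return real + imag  # stack the real half on top of the imaginary half

stacked = [[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]
print(to_complex(stacked)[0][0])           # (1+3j)
print(multiply_stacked(stacked, stacked))  # (1+3j)**2 = -8+6j, (2+4j)**2 = -12+16j
# [[-8, -8, -8], [-12, -12, -12], [6, 6, 6], [16, 16, 16]]
```

The result is again a stacked matrix, so operations of this kind can be chained without leaving real-valued storage.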

Installation

pip install pytorch-complex-tensor

Example:

Easy import

from pytorch_complex_tensor import ComplexTensor

Init tensor

# rows 0-1 hold the real part, rows 2-3 the imaginary part
# equivalent to:
# np.asarray([[1+3j, 1+3j, 1+3j], [2+4j, 2+4j, 2+4j]]).astype(np.complex64)
C = ComplexTensor([[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]])
C.requires_grad = True

Pretty printing

print(C)
# tensor([['(1.0+3.0j)' '(1.0+3.0j)' '(1.0+3.0j)'],
#         ['(2.0+4.0j)' '(2.0+4.0j)' '(2.0+4.0j)']])

Handles absolute value properly for complex tensors

# complex absolute value implementation
print(C.abs())
# tensor([[3.1623, 3.1623, 3.1623],
#         [4.4721, 4.4721, 4.4721]], grad_fn=<SqrtBackward>)
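
The values above follow from the usual modulus formula sqrt(re^2 + im^2) applied to the two halves. A minimal plain-Python sketch, assuming the stacked layout described earlier (the function name `complex_abs` is hypothetical, not the library's API):

```python
import math

def complex_abs(stacked):
    """Modulus of each complex entry of a stacked (real rows, then imaginary rows) matrix."""
    n = len(stacked) // 2
    real, imag = stacked[:n], stacked[n:]
    return [[math.hypot(r, i) for r, i in zip(r_row, i_row)]
            for r_row, i_row in zip(real, imag)]

magnitudes = complex_abs([[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]])
print([[round(v, 4) for v in row] for row in magnitudes])
# [[3.1623, 3.1623, 3.1623], [4.4721, 4.4721, 4.4721]]
```

Note that the result is a real matrix with half as many rows as the stacked input, which matches the 2x3 output printed above.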

Reports the correct size, treating the first half of the matrix as real and the second half as imaginary

print(C.size())
# torch.Size([2, 3])

Multiplies with both complex and real tensors

import torch

# matrix multiply with a real tensor
# (also works with a complex tensor)
x = torch.Tensor([[3, 3], [4, 4], [2, 2]])
xy = C.mm(x)
print(xy)
# tensor([['(9.0+27.0j)' '(9.0+27.0j)'],
#         ['(18.0+36.0j)' '(18.0+36.0j)']])
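
The result above can be reproduced by hand: with a real right-hand side, the real and imaginary halves are multiplied separately, and with a complex right-hand side the identity (A + iB)(P + iQ) = (AP - BQ) + i(AQ + BP) applies. The following plain-Python sketch illustrates that arithmetic under the stacked layout; the helper names are hypothetical and this is not necessarily how the library implements `mm`.

```python
# Illustrative sketch only; not the library's implementation of mm.

def matmul(a, b):
    """Real matrix product of two matrices given as lists of rows."""
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*b)] for row in a]

def halves(stacked):
    """Split a stacked matrix into (real rows, imaginary rows)."""
    n = len(stacked) // 2
    return stacked[:n], stacked[n:]

def combine(m, n, sign):
    """Element-wise m + sign * n for two equally sized matrices."""
    return [[u + sign * v for u, v in zip(r, s)] for r, s in zip(m, n)]

def complex_mm_real(c, x):
    """(A + iB) @ X = (A @ X) + i(B @ X) for a real matrix X."""
    c_re, c_im = halves(c)
    return matmul(c_re, x) + matmul(c_im, x)

def complex_mm(c, d):
    """(A + iB) @ (P + iQ) = (A@P - B@Q) + i(A@Q + B@P)."""
    c_re, c_im = halves(c)
    d_re, d_im = halves(d)
    real = combine(matmul(c_re, d_re), matmul(c_im, d_im), -1)
    imag = combine(matmul(c_re, d_im), matmul(c_im, d_re), +1)
    return real + imag

C = [[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]
x = [[3, 3], [4, 4], [2, 2]]
print(complex_mm_real(C, x))
# [[9, 9], [18, 18], [27, 27], [36, 36]]  i.e. 9+27j and 18+36j

# Multiplying by the purely imaginary column i * [1, 1, 1]:
i_ones = [[0], [0], [0], [1], [1], [1]]
print(complex_mm(C, i_ones))
# [[-9], [-12], [3], [6]]  i.e. -9+3j and -12+6j
```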

Reduce ops return a ComplexScalar

xy = xy.sum()

# this is now a complex scalar (thin wrapper with .real, .imag)
print(type(xy))
# pytorch_complex_tensor.complex_scalar.ComplexScalar

print(xy)
# (54+126j)

A ComplexScalar can be used for gradients without breaking anything (it differentiates with respect to the real part)

# calculate dxy / dC
# for complex scalars, grad is wrt the real part
xy.backward()
print(C.grad)
# tensor([['(6.0-0.0j)' '(8.0-0.0j)' '(4.0-0.0j)'],
#         ['(6.0-0.0j)' '(8.0-0.0j)' '(4.0-0.0j)']])
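
One way to sanity-check these numbers: the real part of the summed product depends only on the real half of C, and its derivative with respect to entry (i, k) is the sum of row k of x, giving 6, 8 and 4. The plain-Python finite-difference sketch below checks this under the assumption that "grad with respect to the real part" means differentiating Re(sum(C @ x)); the function names are hypothetical and no autograd is involved.

```python
# Illustrative finite-difference check; assumes the loss is Re(sum(C @ x)).

def real_part_of_sum(stacked, x):
    """Real part of sum(C @ x), which only involves the real half of C."""
    n = len(stacked) // 2
    real = stacked[:n]
    return sum(r * xv for row in real for r, x_row in zip(row, x) for xv in x_row)

def numerical_grad(f, stacked, eps=1e-3):
    """Central-difference gradient of f with respect to every stacked entry."""
    grad = []
    for i, row in enumerate(stacked):
        grad_row = []
        for j in range(len(row)):
            plus = [list(r) for r in stacked]
            minus = [list(r) for r in stacked]
            plus[i][j] += eps
            minus[i][j] -= eps
            grad_row.append((f(plus) - f(minus)) / (2 * eps))
        grad.append(grad_row)
    return grad

C = [[1.0] * 3, [2.0] * 3, [3.0] * 3, [4.0] * 3]
x = [[3.0, 3.0], [4.0, 4.0], [2.0, 2.0]]
grad = numerical_grad(lambda s: real_part_of_sum(s, x), C)
print([[round(g, 6) for g in row] for row in grad])
# [[6.0, 8.0, 4.0], [6.0, 8.0, 4.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
```

The first two rows are the gradient of the real half and the last two rows (all zero) are the gradient of the imaginary half, consistent with the zero imaginary parts printed above.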

Supports all section (indexing and slicing) ops

print(C[-1])
print(C[0, 0:-2, ...])
print(C[0, ..., 0])

Supported ops:

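| Operation   | complex tensor | real tensor | complex scalar | real scalar |
|-------------|----------------|-------------|----------------|-------------|
| addition    | Y              | Y           | Y              | Y           |
| subtraction | Y              | Y           | Y              | Y           |
| multiply    | Y              | Y           | Y              | Y           |
| mm          | Y              | Y           | Y              | Y           |
| abs         | Y              | -           | -              | -           |
| t           | Y              | -           | -              | -           |
| grads       | Y              | Y           | Y              | Y           |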