wavefrontshaping / complexPyTorch

A high-level toolbox for using complex valued neural networks in PyTorch
MIT License

UserWarning: Casting complex values to real discards the imaginary part Issue #10

Closed Zongyang-Li closed 2 years ago

Zongyang-Li commented 3 years ago

Hi, when I run the example code, it always raises the complex-value warning. I believe this hurts the final performance of the network, because only half of the complex data (the real part) is considered during training. Is there any way to fix this?

Thank you.

wavefrontshaping commented 3 years ago

Ok, there were two things:

First, in the functions complex_dropout and complex_dropout2d, I create a mask that is applied to both the real and imaginary parts to set some elements to zero. To do so, I used to create a complex tensor full of ones, cast it to float, and pass it to the real dropout function to generate the mask. The cast from complex to real triggered the warning, but it was harmless because that tensor had no imaginary part. It was awkward anyway, so I changed it.
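As a minimal sketch of one way to build such a shared mask without any complex-to-real cast: the mask is created directly as a real tensor and then multiplied into the complex input. The function name complex_dropout_sketch is hypothetical, and this is not necessarily how the library implements the fix.

```python
import torch
import torch.nn.functional as F


def complex_dropout_sketch(x, p=0.5, training=True):
    """Apply one shared real dropout mask to the real and imaginary parts.

    Hypothetical helper for illustration only.
    """
    # Build the mask as a real tensor from the start, so nothing complex
    # ever has to be cast to a real dtype.
    mask = torch.ones(x.shape, dtype=torch.float32, device=x.device)
    # F.dropout zeroes entries with probability p and rescales the
    # surviving entries by 1 / (1 - p).
    mask = F.dropout(mask, p, training)
    # A real mask times a complex tensor scales both parts identically;
    # type promotion keeps the result complex.
    return mask * x


x = torch.randn(4, 4, dtype=torch.complex64)
y = complex_dropout_sketch(x, p=0.5)
print(y.dtype)
```

Because the same real mask multiplies the whole complex value, an element is either dropped in both parts or kept in both parts.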

Secondly, there is still a warning triggered during backpropagation. If you create a model with only a complex convolutional layer, for instance, you get this warning. What I understand is that, because under the hood I use two real convolutional filters for the real and imaginary parts while the input of the module is complex, complex values are cast to real types at some point. That does not mean the result is wrong, since the gradient for each part should be real. There is probably a way to understand what happens in more detail, but the most practical approach is simply to test the gradients with gradcheck. gradcheck works with complex inputs: it tests the gradient in both directions of the complex plane, and everything seems to work. You could test ComplexConv2d and ComplexLinear this way:

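import torch
from torch.autograd import gradcheck
from complexPyTorch.complexLayers import ComplexConv2d, ComplexLinear

# Check the gradients of a single complex convolutional layer.
conv = ComplexConv2d(in_channels=1, out_channels=1, kernel_size=2)
x = torch.randn(1, 1, 3, 3, dtype=torch.complex64, requires_grad=True)
# Loose eps/atol because the check runs in single precision.
print(gradcheck(conv, x, eps=1e-4, atol=1e-3))
# Output: True

Similarly:

# Check the gradients of a single complex fully connected layer.
fc = ComplexLinear(in_features=5, out_features=5)
x = torch.randn(1, 1, 5, dtype=torch.complex64, requires_grad=True)
print(gradcheck(fc, x, eps=1e-4, atol=1e-3))
# Output: True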
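
To make the "two real filters" explanation concrete, here is a rough, self-contained sketch of a complex convolution assembled from two real Conv2d layers, checked with gradcheck in double precision. The class TwoRealConv2d is a hypothetical stand-in written for this illustration, not the library's ComplexConv2d, and the bias is omitted for simplicity.

```python
import torch
import torch.nn as nn
from torch.autograd import gradcheck


class TwoRealConv2d(nn.Module):
    """Hypothetical stand-in: complex convolution from two real Conv2d layers."""

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        # conv_r holds the real part of the weights, conv_i the imaginary part.
        self.conv_r = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False)
        self.conv_i = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False)

    def forward(self, x):
        # (W_r + i W_i) * (x_r + i x_i)
        #   = (W_r x_r - W_i x_i) + i (W_r x_i + W_i x_r)
        real = self.conv_r(x.real) - self.conv_i(x.imag)
        imag = self.conv_r(x.imag) + self.conv_i(x.real)
        # Each branch is purely real; recombine into one complex tensor.
        return torch.complex(real, imag)


# Double precision lets gradcheck run with its default tolerances.
conv = TwoRealConv2d(1, 1, 2).double()
x = torch.randn(1, 1, 3, 3, dtype=torch.complex128, requires_grad=True)
ok = gradcheck(conv, (x,))
print(ok)
```

Every intermediate tensor inside forward is real, which is consistent with the remark above that the gradient of each part is real. gradcheck raises an error by default if the numerical and analytical gradients disagree.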