wavefrontshaping / complexPyTorch

A high-level toolbox for using complex valued neural networks in PyTorch
MIT License

grad computations #3

Closed: xmax1 closed this issue 11 months ago

xmax1 commented 4 years ago

Hey, and thanks for the great work.

Just a question on the grad computations (though this could be my lack of understanding of autograd and complex analysis).

For the forward pass you correctly have, for example in ComplexLinear:

```python
class ComplexLinear(Module):

    def __init__(self, in_features, out_features):
        super(ComplexLinear, self).__init__()
        self.fc_r = Linear(in_features, out_features)
        self.fc_i = Linear(in_features, out_features)

    def forward(self, input_r, input_i):
        return self.fc_r(input_r)-self.fc_i(input_i), \
               self.fc_r(input_i)+self.fc_i(input_r)
```

So the forward pass computes Re(x) Re(w) - Im(x) Im(w) + i ( Re(x) Im(w) + Im(x) Re(w) ), which is good. Then we want to compute the gradients of the weights.

Via complex differentiation we find that, for f(w) = w*x with w, x ∈ ℂ, df/dw = x

(which holds under the Cauchy-Riemann conditions), but when I compute the derivatives of the real/imaginary components of the network I get different answers from the solution above (code below):

```python
import torch as tc
import numpy as np
from torch import nn

class ComplexLinear(nn.Module):

    def __init__(self, in_features, out_features):
        super(ComplexLinear, self).__init__()
        self.fc_r = nn.Linear(in_features, out_features)
        self.fc_i = nn.Linear(in_features, out_features)

    def forward(self, input_r, input_i):
        return self.fc_r(input_r)-self.fc_i(input_i), \
               self.fc_r(input_i)+self.fc_i(input_r)

class Net(nn.Module):
    def __init__(self, n_input, n_output):
        super(Net, self).__init__()

        self.layer = ComplexLinear(n_input, n_output)

    def forward(self, x):
        xr = x[:,:,0]
        xi = x[:,:,1]
        out = self.layer(xr, xi)
        return out

X = tc.tensor([[[-0.16,0.51]]], requires_grad=True)

# First dim # datapoints, second dim # variables, third dim complex component
net = Net(X.shape[1],1)

Y = net(X)
Y[0].backward()
Y[1].backward()
print(net.layer.fc_r.weight.grad)
print(net.layer.fc_i.weight.grad)
```

This returns tensor([[0.3500]]) and tensor([[-0.6700]]),

whereas I think the answer should be X, i.e. tensor([[-0.16]]) and tensor([[0.51]]).
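(If I trace the two backward() calls through by hand, the printed numbers are at least self-consistent: the gradients accumulate over both calls, so fc_r.weight.grad = xr + xi = -0.16 + 0.51 = 0.35 and fc_i.weight.grad = xr - xi = -0.16 - 0.51 = -0.67; they are just not the holomorphic derivative I was expecting.)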

... Really looking forward to your answer, and I'm happy to contribute. I'm going to be working on complex graph neural networks, and there doesn't seem to be anything else online about this, so I could integrate it with your package?

Max

wavefrontshaping commented 4 years ago

Hi,

Thanks for the comment.

To be honest, I did not pay attention to the gradients and just let autograd do its job. It seems that the results are indeed different from the standard complex derivative. It would then be necessary to implement the backward() function by hand. As that may be slow, it is probably best to use the JIT capabilities of PyTorch (https://pytorch.org/docs/stable/jit.html).
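As an illustration of what that could look like, here is a minimal sketch (not part of the package; the ComplexMulFunction name is just for the example) of an element-wise complex multiply on (real, imaginary) pairs with a hand-written backward via torch.autograd.Function. The backward below reproduces the usual real-valued chain rule, and the grad_wr/grad_wi lines are where a different convention, such as the holomorphic derivative discussed above, could be substituted:

```python
import torch

class ComplexMulFunction(torch.autograd.Function):
    """Element-wise complex multiply on (real, imag) tensor pairs with a custom backward."""

    @staticmethod
    def forward(ctx, xr, xi, wr, wi):
        ctx.save_for_backward(xr, xi, wr, wi)
        # (xr + i*xi) * (wr + i*wi)
        return xr * wr - xi * wi, xr * wi + xi * wr

    @staticmethod
    def backward(ctx, grad_yr, grad_yi):
        xr, xi, wr, wi = ctx.saved_tensors
        # standard real-valued chain rule for the forward above
        grad_xr = grad_yr * wr + grad_yi * wi
        grad_xi = -grad_yr * wi + grad_yi * wr
        # these two lines are where another gradient convention could be plugged in
        grad_wr = grad_yr * xr + grad_yi * xi
        grad_wi = -grad_yr * xi + grad_yi * xr
        return grad_xr, grad_xi, grad_wr, grad_wi
```

A layer's forward would then call ComplexMulFunction.apply(xr, xi, wr, wi), so the gradient flows through this backward rather than the one autograd derives.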

Unfortunately, I have no time right now to dig into it, but I will look into it when I can.

nitinnilesh commented 3 years ago

Hi,

Great work there. I have redefined just the Linear layer by converting the weights to a complex type, and I use torch.matmul to multiply the weights and inputs. Please take a look. I would request @xmax1 to check this implementation in terms of gradient calculation. Thanks!

```python
import torch
from torch import nn

class ComplexLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super(ComplexLinear, self).__init__()
        # complex-valued weight and bias, learned directly as cfloat parameters
        self.weight = nn.Parameter(torch.randn(in_features, out_features, dtype=torch.cfloat))
        self.bias = nn.Parameter(torch.randn(out_features, dtype=torch.cfloat))

    def forward(self, x):
        return torch.matmul(x, self.weight) + self.bias
```
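Appended to the snippet above (with a reasonably recent PyTorch), a quick sanity check of the gradients could look like this; the shapes and values are arbitrary:

```python
# assumes the ComplexLinear defined above is in scope
layer = ComplexLinear(1, 1)
x = torch.tensor([[-0.16 + 0.51j]], dtype=torch.cfloat)

y = layer(x)
# reduce to a real scalar before calling backward()
y.real.sum().backward()

print(layer.weight.grad)  # complex-valued gradient of the weight
print(layer.bias.grad)
```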
wavefrontshaping commented 2 years ago

Hi,

If it works now, that is great. At some point, it was not possible to use complex weights, either because autograd did not work with them or because they did not work on GPUs (I am not sure now). Does it work fine this way? I will perform some tests on my side.
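Something along these lines could serve as a rough first test that complex parameters get gradients through autograd and run on the GPU (just a sketch, reusing the complex-weight ComplexLinear from the previous comment):

```python
import torch

# rough end-to-end check: complex parameters, real-valued loss, GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"

layer = ComplexLinear(4, 2).to(device)          # the complex-weight layer above
x = torch.randn(8, 4, dtype=torch.cfloat, device=device)
target = torch.randn(8, 2, dtype=torch.cfloat, device=device)

loss = (layer(x) - target).abs().pow(2).mean()  # real scalar loss
loss.backward()

print(layer.weight.grad.dtype, layer.weight.grad.device)
```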