wavefrontshaping / complexPyTorch

A high-level toolbox for using complex valued neural networks in PyTorch
MIT License
623 stars, 148 forks

ConvTranspose2d error #14

Closed: Dub21 closed this issue 3 years ago

Dub21 commented 3 years ago

Hello, thanks for your work. I am trying to use it for image-to-image translation, but I got the following error while using ConvTranspose2d: RuntimeError: Input type (CUDAComplexFloatType) and weight type (torch.cuda.FloatTensor) should be the same. Have you been able to use ConvTranspose2d successfully?

Thanks

wavefrontshaping commented 3 years ago

Hi, can you please provide a minimal example to reproduce the error? Thx

Dub21 commented 3 years ago

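import torch
import torch.nn as nn
from torch.nn import ConvTranspose2d
from complexPyTorch.complexLayers import ComplexConv2d

class ComplexNet_Unet(nn.Module):

    def __init__(self):
        super(ComplexNet_Unet, self).__init__()
        self.conv1 = ComplexConv2d(1, 8, 5, 2, 2)
        # torch.nn.ConvTranspose2d: its weights are real (float32)
        self.conv2 = ConvTranspose2d(8, 1, 1, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x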

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ComplexNet_Unet().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

loss = nn.L1Loss()

def train(model, device, data, target, optimizer, epochs):
    model.train()
    for epoch in range(epochs):

        #data = torch.complex(torch.DoubleTensor(np.real(X)),torch.DoubleTensor(np.imag(X)))
        #target = torch.complex(torch.DoubleTensor(np.real(y)),torch.DoubleTensor(np.imag(y)))

        #data, target = data2.to(device).type(torch.complex64), target2.to(device)
        optimizer.zero_grad()
        output = model(data)
        print(output.shape)
        value = loss(output, target)
        print(value)
        value.backward()
        optimizer.step()
        # a single (data, target) batch is used, so log every 100 epochs
        if epoch % 100 == 0:
            print('Train Epoch: {:3}\tLoss: {:.6f}'.format(epoch, value.item()))

# Run training for 500 epochs
train(model, device, data, target, optimizer, 500)
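The commented-out lines in train() build the complex input from NumPy arrays with torch.complex and then cast to torch.complex64. The cast matters: torch.complex on two float64 tensors produces complex128, which would not match float32 layer weights either. Below is a sketch comparing that route with a direct torch.from_numpy conversion; the random array X is a hypothetical stand-in for the poster's data.

```python
import numpy as np
import torch

# Hypothetical complex-valued input standing in for X (NumPy complex128).
X = np.random.randn(4, 1, 16, 16) + 1j * np.random.randn(4, 1, 16, 16)

# Route from the commented-out lines: split into real and imaginary parts.
# torch.complex of two float64 tensors yields complex128, hence the cast.
data_split = torch.complex(
    torch.from_numpy(np.ascontiguousarray(X.real)),
    torch.from_numpy(np.ascontiguousarray(X.imag)),
).to(torch.complex64)

# Direct route: from_numpy keeps complex128; cast to complex64 so the
# real and imaginary parts are float32, like the layer weights.
data_direct = torch.from_numpy(X).to(torch.complex64)

print(data_split.dtype, data_direct.dtype)
print(torch.equal(data_split, data_direct))
```

Both routes give the same complex64 tensor, so the one-line version can replace the split-and-recombine code. The result still needs .to(device) before it is passed to the model.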

wavefrontshaping commented 3 years ago

That is not so minimal...
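For comparison, the same class of error can be triggered in a few lines of plain PyTorch, without complexPyTorch or a training loop. This is a sketch assuming a PyTorch version with native complex tensors. A native nn.ConvTranspose2d holds real float32 parameters, so passing it a complex64 tensor, which is what a complex-valued layer such as ComplexConv2d outputs, fails the dtype check.

```python
import torch
import torch.nn as nn

def try_real_layer_on_complex_input():
    """Return the RuntimeError message, or None if the call succeeded."""
    # Native layer: weight and bias are real float32 tensors.
    layer = nn.ConvTranspose2d(8, 1, kernel_size=1, stride=1)
    # complex64 input, as produced by a complex-valued layer.
    x = torch.randn(2, 8, 16, 16, dtype=torch.complex64)
    try:
        layer(x)
    except RuntimeError as err:
        # Exact wording depends on PyTorch version and device; the report
        # above shows the CUDA variant.
        return str(err)
    return None

message = try_real_layer_on_complex_input()
print(message)
```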

wavefrontshaping commented 3 years ago

Before testing anything, I see that you use the native ConvTranspose2d instead of the complex counterpart ComplexConvTranspose2d from complexPyTorch. Is that on purpose?
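In the model above, that means replacing the second layer with ComplexConvTranspose2d(8, 1, 1, 1). This assumes it is importable from complexPyTorch.complexLayers alongside ComplexConv2d and takes the same leading positional arguments as the native layer. To illustrate why a complex counterpart accepts complex input, here is a sketch of the general idea in plain PyTorch: two real transposed convolutions combined with the rules of complex multiplication. The class name NaiveComplexConvTranspose2d is hypothetical, and this is not claimed to be complexPyTorch's exact implementation.

```python
import torch
import torch.nn as nn

class NaiveComplexConvTranspose2d(nn.Module):
    """Hypothetical sketch: complex weight W = W_r + i*W_i held as two real layers."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        self.conv_r = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
        self.conv_i = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)

    def forward(self, x):
        # (W_r + i W_i)(x_r + i x_i) = (W_r x_r - W_i x_i) + i (W_r x_i + W_i x_r)
        # Each real layer adds its own bias, so the effective complex bias is
        # (b_r - b_i) + i (b_r + b_i).
        real = self.conv_r(x.real) - self.conv_i(x.imag)
        imag = self.conv_r(x.imag) + self.conv_i(x.real)
        return torch.complex(real, imag)

torch.manual_seed(0)
layer = NaiveComplexConvTranspose2d(8, 1, kernel_size=1)
x = torch.randn(2, 8, 16, 16, dtype=torch.complex64)
y = layer(x)
print(y.dtype, tuple(y.shape))
```

Only real float32 tensors ever reach the native layers, so the dtype check that fails in the original model is never triggered, and the output stays complex64 for the next complex layer or loss.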