josiahwsmith10 / complextorch

GNU General Public License v3.0
19 stars 5 forks source link

Unable to compute gradient with CVConv2d #5

Open rachelglenn opened 3 months ago

rachelglenn commented 3 months ago

I am trying to define a 2D model using the CVConv2d layer, but I get an error when I compute the gradient. To narrow it down, I defined a network with only a single CVConv2d layer (followed by CVBatchNorm2d), and I am still not able to compute the gradient:

class testModel(nn.Module):
    """Minimal complex-valued network used to reproduce the gradient error.

    The active pipeline is a single ``CVConv2d`` (3x3 kernel, no bias)
    followed by ``CVBatchNorm2d``. The remaining stages are kept as
    commented-out lines so the reproduction stays as small as possible.

    Args:
        in_chans: Number of input channels given to the convolution.
        out_chans: Number of output channels of the convolution and the
            number of features given to the batch norm.
        chans: Stored on the instance; not used by the active layers.
        num_pool_layers: Stored on the instance; not used by the active
            layers.
        drop_prob: Stored on the instance; only referenced by the disabled
            ``CVDropout`` stages.
    """

    def __init__(
        self,
        in_chans: int,
        out_chans: int,
        chans: int = 32,
        num_pool_layers: int = 4,
        drop_prob: float = 0.0,
    ):
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.chans = chans
        self.num_pool_layers = num_pool_layers
        self.drop_prob = drop_prob
        self.layers = nn.Sequential(
            CVConv2d(
                in_channels=in_chans,
                out_channels=out_chans,
                kernel_size=3,
                bias=False,
            ),
            CVBatchNorm2d(out_chans),
            # Disabled while isolating the failure:
            # CVSplitTanh(),
            # CVDropout(drop_prob),
            # CVConv2d(in_channels=out_chans, out_channels=out_chans, kernel_size=3,  bias=False),
            # CVBatchNorm2d(out_chans),
            # CVSplitTanh(),
            # CVDropout(drop_prob),
        )

    def forward(self, image):
        """Pass ``image`` through ``self.layers`` and return the result."""
        # No modules are constructed here: a module built inside forward()
        # is re-created on every call and is never registered on this
        # model, so it would be dead weight at best.
        return self.layers(image)

# Build the single-layer test network and move it to the target device.
net = testModel(
    in_chans=1,
    out_chans=1,
    chans=32,
    num_pool_layers=4,
    drop_prob=0.2,
).to(dev)

# Random complex input, re-wrapped as a CVTensor from its real/imag parts.
size_image = 128
img_shape = (1, 1, size_image, size_image)
img = torch.rand(img_shape, dtype=torch.complex64).to(dev)
img = CVTensor(img.real, img.imag)

out = net(img)
# out = torch.complex(out.real, out.imag)

# Random complex target with the same shape as the network output.
target = torch.rand(out.shape, dtype=torch.complex64).to(dev)

# The loss is the sum of the BCE-with-logits terms on the real and
# imaginary parts, computed separately.
real_loss = F.binary_cross_entropy_with_logits(out.real, target.real)
imag_loss = F.binary_cross_entropy_with_logits(out.imag, target.imag)
loss = real_loss + imag_loss

loss.backward()
net.zero_grad()

I get the error: RuntimeError: output with shape [2, 1] doesn't match the broadcast shape [2, 2]