wavefrontshaping / complexPyTorch

A high-level toolbox for using complex-valued neural networks in PyTorch
MIT License

ComplexBatchNorm1d Error #20

Open Metamorphosis-chm opened 2 years ago

Metamorphosis-chm commented 2 years ago

  File "D:/Pycharm/coplexcnn/train.py", line 124, in
    y_hat = net(X)
  File "C:\Users\MyPC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:/Pycharm/complexcnn/train.py", line 79, in forward
    x = self.bn1(x)
  File "C:\Users\MyPC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Pycharm\complexcnn\complexLayers.py", line 294, in forward
    self.running_mean = exponential_average_factor * mean \
RuntimeError: The size of tensor a (253) must match the size of tensor b (32) at non-singleton dimension 1

How can I solve this? Thanks.
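A minimal sketch that reproduces the same RuntimeError, assuming the layer receives 3-D input of shape (N, C, L) with C = 32 channels and L = 253 samples (the batch size of 16 is an arbitrary assumption). Averaging over the batch dimension only leaves a (32, 253) tensor, which cannot be broadcast against a running-mean buffer of shape (32,):

```python
import torch

# Assumed shapes: N=16 (arbitrary), C=32 (num_features), L=253 (sequence length)
x = torch.randn(16, 32, 253, dtype=torch.complex64)
running_mean = torch.zeros(32)  # running buffer shaped (num_features,)
momentum = 0.1

mean_r = x.real.mean(dim=0)  # reduces only the batch dim -> shape (32, 253)
print(mean_r.shape)  # torch.Size([32, 253])

message = ""
try:
    # Broadcasting aligns trailing dims: 253 vs 32 -> RuntimeError
    running_mean = momentum * mean_r + (1 - momentum) * running_mean
except RuntimeError as err:
    message = str(err)
print(message)
```

The mismatch comes from the reduction dimensions, not from the data itself, which points at the statistics computation inside the layer rather than at the training loop.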

Glen9010 commented 1 year ago

Maybe you can try again; the code was updated 3 months ago. Besides, there is still a memory leak in ComplexBatchNorm1d. You can fix it by adding "with torch.no_grad():" after line 254.
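A small sketch of why the torch.no_grad() guard matters, assuming (as in a training step) that the batch statistics are computed from activations that require gradients. The variable names are illustrative, not the library's. Without the guard, each running-statistic update is recorded by autograd and chains onto the previous step's graph, so that history is kept alive; with the guard, the running value is a plain tensor with no history.

```python
import torch

momentum = 0.1
running_no_guard = torch.zeros(4)
running_guarded = torch.zeros(4)

for step in range(3):
    x = torch.randn(8, 4, requires_grad=True)  # stand-in for layer activations
    batch_mean = x.mean(dim=0)

    # Without no_grad: the update is tracked by autograd, so the new running
    # value has a grad_fn linking it to this batch and to the previous update.
    running_no_guard = momentum * batch_mean + (1 - momentum) * running_no_guard

    # With no_grad: same arithmetic, but no graph is recorded.
    with torch.no_grad():
        running_guarded = momentum * batch_mean + (1 - momentum) * running_guarded

print(running_no_guard.requires_grad, running_no_guard.grad_fn is not None)  # True True
print(running_guarded.requires_grad, running_guarded.grad_fn is None)  # False True
```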

saugatkandel commented 1 year ago

In case someone is reading this: I had to change the BatchNorm code slightly to make it work properly. Here are my changes to get ComplexBatchNorm2d working (ComplexBatchNorm1d is similar). The formatting is a bit unusual because I use Black as my default formatter.


import torch
from torch.nn import Module, Parameter, init


class _ComplexBatchNorm(Module):
    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
    ):
        super(_ComplexBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features, 3))
            self.bias = Parameter(torch.Tensor(num_features, 2))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer("running_mean_r", torch.zeros(num_features))
            self.register_buffer("running_mean_i", torch.zeros(num_features))
            self.register_buffer("running_covar", torch.zeros(num_features, 3))
            self.running_covar[:, 0] = 1 / 1.4142135623730951
            self.running_covar[:, 1] = 1 / 1.4142135623730951
            self.register_buffer(
                "num_batches_tracked", torch.tensor(0, dtype=torch.long)
            )
        else:
            self.register_parameter("running_mean_r", None)
            self.register_parameter("running_mean_i", None)
            self.register_parameter("running_covar", None)
            self.register_parameter("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean_r.zero_()
            self.running_mean_i.zero_()
            self.running_covar.zero_()
            self.running_covar[:, :2] = 1 / 1.4142135623730951
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            init.constant_(self.weight[:, :2], 1 / 1.4142135623730951)
            init.zeros_(self.weight[:, 2])
            init.zeros_(self.bias)

class ComplexBatchNorm2d(_ComplexBatchNorm):
    def forward(self, inputs):
        exponential_average_factor = 0.0 if self.momentum is None else self.momentum

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        if self.training or (not self.training and not self.track_running_stats):
            # calculate mean of real and imaginary part
            # mean does not support automatic differentiation for outputs with complex dtype.

            mean_r = inputs.real.mean([0, 2, 3])
            mean_i = inputs.imag.mean([0, 2, 3])
        else:
            mean_r = self.running_mean_r.clone()
            mean_i = self.running_mean_i.clone()

        if self.training and self.track_running_stats:
            # update running mean
            with torch.no_grad():

                self.running_mean_r[:] = (
                    exponential_average_factor * mean_r
                    + (1 - exponential_average_factor) * self.running_mean_r
                )
                self.running_mean_i[:] = (
                    exponential_average_factor * mean_i
                    + (1 - exponential_average_factor) * self.running_mean_i
                )

        inputs = inputs - (mean_r + 1j * mean_i)[None, :, None, None]

        if self.training or (not self.training and not self.track_running_stats):
            # Elements of the covariance matrix (biased for train)

            # n = input.numel() / input.size(1)
            Crr = inputs.real.pow(2).mean(dim=[0, 2, 3]) + self.eps
            Cii = inputs.imag.pow(2).mean(dim=[0, 2, 3]) + self.eps
            Cri = (inputs.real * inputs.imag).mean(dim=[0, 2, 3])
        else:
            Crr = self.running_covar[:, 0] + self.eps
            Cii = self.running_covar[:, 1] + self.eps
            Cri = self.running_covar[:, 2]  # +self.eps

        if self.training and self.track_running_stats:
            with torch.no_grad():
                self.running_covar[:, 0] = (
                    exponential_average_factor * Crr
                    + (1 - exponential_average_factor) * self.running_covar[:, 0]
                )

                self.running_covar[:, 1] = (
                    exponential_average_factor * Cii
                    + (1 - exponential_average_factor) * self.running_covar[:, 1]
                )

                self.running_covar[:, 2] = (
                    exponential_average_factor * Cri
                    + (1 - exponential_average_factor) * self.running_covar[:, 2]
                )

        # calculate the inverse square root of the covariance matrix
        det = Crr * Cii - Cri.pow(2)

        s = torch.sqrt(det)
        t = torch.sqrt(Cii + Crr + 2 * s)
        inverse_st = 1.0 / (s * t)
        Rrr = (Cii + s) * inverse_st
        Rii = (Crr + s) * inverse_st
        Rri = -Cri * inverse_st

        inputs = (
            Rrr[None, :, None, None] * inputs.real
            + Rri[None, :, None, None] * inputs.imag
        ).type(torch.complex64) + 1j * (
            Rii[None, :, None, None] * inputs.imag
            + Rri[None, :, None, None] * inputs.real
        ).type(
            torch.complex64
        )

        if self.affine:
            inputs = (
                self.weight[None, :, 0, None, None] * inputs.real
                + self.weight[None, :, 2, None, None] * inputs.imag
                + self.bias[None, :, 0, None, None]
            ).type(torch.complex64) + 1j * (
                self.weight[None, :, 2, None, None] * inputs.real
                + self.weight[None, :, 1, None, None] * inputs.imag
                + self.bias[None, :, 1, None, None]
            ).type(
                torch.complex64
            )

        return inputs
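The lines computing Rrr, Rii and Rri use a closed-form inverse square root of the per-channel 2x2 covariance matrix [[Crr, Cri], [Cri, Cii]]: with s = sqrt(det) and t = sqrt(Crr + Cii + 2s), the inverse square root is [[Cii + s, -Cri], [-Cri, Crr + s]] / (s t). The sketch below checks that formula numerically against an eigendecomposition for one arbitrary covariance matrix; the helper name inv_sqrt_2x2 is hypothetical.

```python
import torch


def inv_sqrt_2x2(Crr, Cri, Cii):
    """Hypothetical helper: closed-form inverse square root of [[Crr, Cri], [Cri, Cii]].

    Mirrors the Rrr / Rii / Rri computation in forward() above.
    """
    s = torch.sqrt(Crr * Cii - Cri.pow(2))  # sqrt of the determinant
    t = torch.sqrt(Cii + Crr + 2 * s)  # sqrt of (trace + 2 s)
    inverse_st = 1.0 / (s * t)
    Rrr = (Cii + s) * inverse_st
    Rii = (Crr + s) * inverse_st
    Rri = -Cri * inverse_st
    return Rrr, Rri, Rii


# An arbitrary symmetric positive-definite covariance matrix
V = torch.tensor([[2.0, 0.6], [0.6, 1.0]], dtype=torch.float64)
Rrr, Rri, Rii = inv_sqrt_2x2(V[0, 0], V[0, 1], V[1, 1])
R = torch.stack([torch.stack([Rrr, Rri]), torch.stack([Rri, Rii])])

# Reference V^{-1/2} from the eigendecomposition V = Q diag(w) Q^T
w, Q = torch.linalg.eigh(V)
V_inv_sqrt = Q @ torch.diag(w.rsqrt()) @ Q.T

print(torch.allclose(R, V_inv_sqrt))  # True
print(torch.allclose(R @ V @ R, torch.eye(2, dtype=torch.float64)))  # True
```

Because R V R equals the identity, multiplying the centered (real, imag) pairs by R gives each channel unit variance in both parts and zero real/imaginary correlation before the affine step.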

The exact changes are as follows:

  1. From my reading of the linked paper and the associated code (https://github.com/ChihebTrabelsi/deep_complex_networks), the diagonal entries of running_covar and weight should be initialized to 1/sqrt(2), not sqrt(2).
  2. The registration and update of the running-mean buffers (running_mean_r and running_mean_i). The buffer assignment was the hardest part to figure out: assigning self.running_mean_r = ... in forward does not work, but self.running_mean_r[:] = ... does. I guess it has something to do with PyTorch internals.
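A small sketch of the difference behind point 2, using a hypothetical one-buffer module. Slice assignment writes into the existing buffer storage, so every holder of a reference to that tensor sees the update; plain assignment rebinds the attribute to a new tensor and leaves older references stale. One plausible way this bites (an assumption, since the setup is not described above): the torch.nn.DataParallel documentation states that only in-place updates to buffers on the first device are kept, so rebinding inside forward is lost.

```python
import torch
from torch import nn


class Tracker(nn.Module):
    """Hypothetical module with a single buffer, used only for this illustration."""

    def __init__(self):
        super().__init__()
        self.register_buffer("running_mean", torch.zeros(3))


m = Tracker()
alias = m.running_mean  # any other reference to the buffer tensor
ptr_before = m.running_mean.data_ptr()

# In-place slice assignment: same storage, the alias sees the new values.
with torch.no_grad():
    m.running_mean[:] = torch.ones(3)
same_storage_inplace = m.running_mean.data_ptr() == ptr_before
print(same_storage_inplace, alias.tolist())  # True [1.0, 1.0, 1.0]

# Plain assignment: the attribute now points to a different tensor.
with torch.no_grad():
    m.running_mean = torch.full((3,), 2.0)
same_storage_rebind = m.running_mean.data_ptr() == ptr_before
print(same_storage_rebind, alias.tolist())  # False [1.0, 1.0, 1.0]
```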
karli262 commented 11 months ago

Hello, I had the same problem you described. I noticed that ComplexBatchNorm2d, which is designed for 4D data (N, C, H, W), computes the mean and variance over three dimensions, e.g. mean_r = input.real.mean([0, 2, 3]), and applies those statistics in the same way, e.g. input = input - mean[None, :, None, None].

ComplexBatchNorm1d, however, only handles 2D data (N, C), whereas my data was 3D (N, C, L), which caused the problem in my case. For example, I changed mean_r = input.real.mean(dim=0).type(torch.complex64) to mean_r = input.real.mean([0, 2]).type(torch.complex64). When applying the mean, I changed input = input - mean[None, ...] to input = input - mean[None, :, None]. After making the same changes for the imaginary part and the covariance, I obtained correctly normalized output for 3D data. I hope this helps.
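A minimal sketch of the change described above, written as a hypothetical standalone function rather than a patch to complexLayers.py: it whitens per channel for either (N, C) or (N, C, L) input by reducing over every dimension except the channel dimension 1 and broadcasting the statistics back with shape (1, C) or (1, C, 1). It leaves out the affine parameters and running statistics for brevity.

```python
import torch


def complex_whiten_1d(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Hypothetical helper: per-channel complex whitening (no affine, no running stats).

    Accepts (N, C) or (N, C, L) input and reduces over every dim except the
    channel dim 1, i.e. mean([0]) for 2-D and mean([0, 2]) for 3-D data.
    """
    if x.dim() not in (2, 3):
        raise ValueError(f"expected 2-D or 3-D input, got {x.dim()}-D")
    dims = [0] if x.dim() == 2 else [0, 2]
    # Broadcast shape for per-channel stats: (1, C) or (1, C, 1),
    # the same role as mean[None, :] / mean[None, :, None]
    shape = (1, x.size(1)) + (1,) * (x.dim() - 2)

    mean = torch.complex(x.real.mean(dims), x.imag.mean(dims))
    x = x - mean.view(shape)

    # Biased per-channel covariance of (real, imag)
    Crr = x.real.pow(2).mean(dims) + eps
    Cii = x.imag.pow(2).mean(dims) + eps
    Cri = (x.real * x.imag).mean(dims)

    # Closed-form inverse square root of [[Crr, Cri], [Cri, Cii]]
    s = torch.sqrt(Crr * Cii - Cri.pow(2))
    t = torch.sqrt(Cii + Crr + 2 * s)
    inverse_st = 1.0 / (s * t)
    Rrr = ((Cii + s) * inverse_st).view(shape)
    Rii = ((Crr + s) * inverse_st).view(shape)
    Rri = (-Cri * inverse_st).view(shape)

    real = Rrr * x.real + Rri * x.imag
    imag = Rri * x.real + Rii * x.imag
    return torch.complex(real, imag)


x = torch.randn(16, 32, 253, dtype=torch.complex64)
y = complex_whiten_1d(x)
print(y.shape)  # torch.Size([16, 32, 253])
```

Choosing the reduction dimensions from x.dim() lets one code path cover both input layouts; in the actual layer, the running buffers would keep shape (num_features,) in both cases.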