wavefrontshaping / complexPyTorch

A high-level toolbox for using complex valued neural networks in PyTorch
MIT License

Problem with ComplexBatchNorm2d #8

Closed by ninfueng 3 years ago

ninfueng commented 3 years ago

Hello,

I have an issue with ComplexBatchNorm2d during testing, i.e. in model.eval() mode (training is fine). The problematic line is:

input = input - mean[None, :, None, None]

The problem seems to be a shape mismatch between mean and input. In this case, the mean shape is [1000, 64, 32, 32] and the input shape is [64, 64]. Note that I used a VGG16-like model with the CIFAR10 dataset.

With the same model, NaiveComplexBatchNorm2d works fine.

wavefrontshaping commented 3 years ago

The bug appeared when I switched to complex tensors and failed to correctly modify the shape of the running_mean parameter. It should be fixed now.
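
For readers who hit a similar error, here is a minimal sketch of why the shape of the running statistic matters for a line like `input - mean[None, :, None, None]`. It uses only the Python standard library and mimics the usual NumPy/PyTorch broadcasting rule. The helper names `broadcast_shape` and `expand_channel_stat` are hypothetical, and the shapes are assumptions chosen for illustration (an (N, C, H, W) batch and a per-channel statistic), not the library's actual internals.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Return the broadcast shape of two shapes, or raise ValueError.

    Shapes are aligned from the right; two sizes are compatible when
    they are equal or when one of them is 1.
    """
    result = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"cannot broadcast {tuple(a)} with {tuple(b)}")
        result.append(y if x == 1 else x)
    return tuple(reversed(result))


def expand_channel_stat(stat_shape):
    """Shape produced by stat[None, :, None, None].

    Each None adds a size-1 axis, ':' keeps the first axis, and any
    remaining axes of the statistic are appended unchanged.
    """
    return (1, stat_shape[0], 1, 1) + tuple(stat_shape[1:])


# Assumed batch shape (N, C, H, W), for illustration only.
input_shape = (1000, 64, 32, 32)

# A per-channel statistic of shape (C,) expands to (1, C, 1, 1) and broadcasts.
good = expand_channel_stat((64,))
print(good, "->", broadcast_shape(input_shape, good))

# A statistic with an extra axis (illustrative stand-in) no longer lines up.
bad = expand_channel_stat((64, 64))
try:
    broadcast_shape(input_shape, bad)
except ValueError as err:
    print("mismatch:", err)
```

The point of the sketch is that the indexing pattern only works when the stored statistic has one entry per channel. Any extra axis on the running statistic survives the indexing and breaks the subtraction, which would be consistent with an error that shows up only in evaluation mode, where running statistics are used instead of batch statistics.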