wavefrontshaping / complexPyTorch

A high-level toolbox for using complex-valued neural networks in PyTorch
MIT License

Complex batch norm2d with track_running_stats=False produces an error #6

Closed: nisanaryal closed this issue 3 years ago

nisanaryal commented 3 years ago

(1 - exponential_average_factor) * self.running_mean
TypeError: unsupported operand type(s) for *: 'float' and 'NoneType'

There is a bug in the batch normalization layer. When it is initialized with track_running_stats=False, the running statistics are not initialized (they stay None, as expected), but during training the layer still tries to update the running mean and running covariance.

The error occurs at this line:
self.running_mean = exponential_average_factor * mean \
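A minimal sketch of why the update fails, in plain Python with floats standing in for tensors. With track_running_stats=False the running mean is None, so the exponential-moving-average update ends up multiplying a float by None. The second half of the update is reconstructed from the error message; the variable values (a factor of 0.1, a batch mean of 0.5) are illustrative assumptions, not taken from complexPyTorch.

```python
# Illustrative reproduction of the failure; floats stand in for tensors.
exponential_average_factor = 0.1  # assumed momentum value, for illustration only
mean = 0.5                        # batch mean computed in this forward pass
running_mean = None               # what the layer holds when track_running_stats=False

try:
    # Exponential-moving-average update, as in the failing line of the issue
    running_mean = exponential_average_factor * mean \
        + (1 - exponential_average_factor) * running_mean
    error_message = None
except TypeError as exc:
    # (1 - 0.1) is a float, and float * None is not defined
    error_message = str(exc)

print(error_message)
# prints: unsupported operand type(s) for *: 'float' and 'NoneType'
```

Because the exception is raised before the assignment completes, running_mean is still None afterwards.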

wavefrontshaping commented 3 years ago

Indeed, track_running_stats was not checked correctly in eval mode. It should be fixed now.
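A hedged sketch of the kind of check the maintainer describes, using plain floats instead of tensors. The helper names update_running_mean and stats_for_normalization are hypothetical and are not the actual complexPyTorch fix. The second helper follows the rule documented for torch.nn.BatchNorm2d: when no running statistics are tracked, batch statistics are used in eval mode as well.

```python
from typing import Optional


def update_running_mean(
    running_mean: Optional[float],
    batch_mean: float,
    exponential_average_factor: float,
    training: bool,
    track_running_stats: bool,
) -> Optional[float]:
    """Hypothetical helper: running mean after one forward pass.

    The moving average is only updated in training mode when running
    statistics are tracked; otherwise the stored value (possibly None)
    is returned unchanged, so float * None is never evaluated.
    """
    if training and track_running_stats:
        return (exponential_average_factor * batch_mean
                + (1 - exponential_average_factor) * running_mean)
    return running_mean


def stats_for_normalization(
    running_mean: Optional[float],
    batch_mean: float,
    training: bool,
    track_running_stats: bool,
) -> float:
    """Hypothetical helper: which mean to normalize with.

    Batch statistics are used during training, and also in eval mode when
    no running statistics exist (running_mean is None in that case).
    """
    if training or not track_running_stats:
        return batch_mean
    return running_mean


# Usage: training without tracked statistics no longer raises
print(update_running_mean(None, 0.5, 0.1, training=True, track_running_stats=False))
print(stats_for_normalization(None, 0.5, training=False, track_running_stats=False))
```

Keeping the two decisions in separate helpers makes it explicit that both the update step (training) and the choice of statistics (eval) depend on track_running_stats.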