King-Of-Knights opened this issue 2 years ago
Indeed, the complex BatchNorm is not optimized, and there are no plans to optimize it in the short term. I am sorry for the trouble caused. The reason is similar to what happens with ComplexPyTorch.
I was having the same problem and came up with this simple solution. According to the authors of ComplexPyTorch, performing batch normalization in a 'naive' way, i.e. separately on the real and imaginary parts, does not have a significant impact on performance compared to the complex formulation of Trabelsi et al.
Here's a TF version of their NaiveComplexBatchNorm layer, which can be used with the Keras functional API:
```python
import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization

def naive_complex_batch_normalization(inputs: tf.Tensor) -> tf.Tensor:
    # Normalize the real and imaginary parts independently.
    real = tf.cast(tf.math.real(inputs), tf.float32)
    imag = tf.cast(tf.math.imag(inputs), tf.float32)
    real_bn = BatchNormalization()(real)
    imag_bn = BatchNormalization()(imag)
    return tf.cast(tf.complex(real_bn, imag_bn), tf.complex64)
```
@NEGU93, would you be interested in a PR implementing this as a proper `tf.keras.layers.Layer` class?
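For reference, a minimal sketch of what such a `Layer` subclass could look like; the class name and details are illustrative, not the final cvnn API:

```python
import tensorflow as tf

class NaiveComplexBatchNormalization(tf.keras.layers.Layer):
    """Applies BatchNormalization independently to the real and imaginary parts.

    Hypothetical sketch of the layer-class version discussed above.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # One BatchNormalization per component, each with its own trainable stats.
        self.real_bn = tf.keras.layers.BatchNormalization()
        self.imag_bn = tf.keras.layers.BatchNormalization()

    def call(self, inputs, training=None):
        real = tf.math.real(inputs)  # complex64 -> float32
        imag = tf.math.imag(inputs)
        real_bn = self.real_bn(real, training=training)
        imag_bn = self.imag_bn(imag, training=training)
        # tf.complex on float32 inputs yields complex64.
        return tf.complex(real_bn, imag_bn)
```

Keeping the two `BatchNormalization` sub-layers as attributes (rather than creating them inside a function) lets Keras track their weights and moving statistics properly.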
Sure. I am not sure what they base that claim on; from my point of view, a naive implementation may have a very negative impact on the phase, which is a crucial aspect of CVNN merits (Ref). But then again, using CReLU should have a similar impact, and it still works well, so... why not?
Please submit your PR, and thank you for the contribution!
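The phase concern mentioned above can be illustrated numerically: when the real and imaginary parts have different statistics, normalizing them independently rescales them by different factors and so rotates the phase of individual samples. A NumPy sketch (not the cvnn implementation; the 10x variance ratio is an arbitrary choice for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
# Complex samples whose real part has much larger variance than the imaginary part.
real = 10.0 * rng.standard_normal(1000)
imag = 1.0 * rng.standard_normal(1000)
z = real + 1j * imag

def standardize(x):
    # Batch standardization over the samples, as naive BN does on each part.
    return (x - x.mean()) / x.std()

z_bn = standardize(real) + 1j * standardize(imag)

# The real part is shrunk by ~10x while the imaginary part is left roughly
# unchanged, so sample phases rotate substantially.
phase_shift = np.angle(z_bn) - np.angle(z)
print(np.abs(phase_shift).max() > 0.1)  # prints True
```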
Here is an implementation of a small 1D CNN as an example, until that PR is integrated into the cvnn package:
```python
import tensorflow as tf
from tensorflow import keras
from cvnn import layers

def get_model(input_len=1000, activation_func='crelu'):
    inputs = layers.complex_input(shape=(input_len, 1))
    conv0 = layers.ComplexConv1D(64, 7, input_shape=(input_len, 1),
                                 activation=activation_func)(inputs)
    # Naive complex batch norm: normalize real and imaginary parts separately.
    bn_r0 = keras.layers.BatchNormalization()(tf.cast(tf.math.real(conv0), tf.float32))
    bn_i0 = keras.layers.BatchNormalization()(tf.cast(tf.math.imag(conv0), tf.float32))
    p0 = layers.ComplexAvgPooling1D(pool_size=2)(tf.cast(tf.complex(bn_r0, bn_i0), tf.complex64))
    out = layers.ComplexConv1D(32, 3, activation=activation_func)(p0)
    return tf.keras.Model(inputs, out)
```
Hi there, @NEGU93. Thanks for the great effort in making this library. It really accelerates my research on a signal-recognition task. This TF 2.0 version indeed helped me deploy to an edge device with the help of TFLite. However, I found that `ComplexBatchNormalization()` terribly slows down the training process. To give one example to reproduce: it costs me almost 10 minutes to train one epoch, but when I substitute `ComplexBatchNormalization()` with `BatchNormalization()`, it only costs half a minute. Any ideas?