NEGU93 / cvnn

Library to help implement a complex-valued neural network (cvnn) using tensorflow as back-end
https://complex-valued-neural-networks.readthedocs.io/
MIT License

Terribly slow training caused by ComplexBatchNormalization() #30

Open King-Of-Knights opened 2 years ago

King-Of-Knights commented 2 years ago

Hi there, @NEGU93. Thanks for the great effort in making this library. It has really accelerated my research on a signal recognition task, and this TF 2.0 version indeed helped me deploy to an edge device with the help of TFLite. However, I found that ComplexBatchNormalization() terribly slows down the training process. Here is an example to reproduce:

import numpy as np
from tensorflow.keras.models import Model
import tensorflow
from cvnn.layers import ComplexConv1D, ComplexDense, ComplexBatchNormalization, ComplexFlatten, complex_input

X_train = np.random.rand(18000, 4096, 2)
Y_train = np.random.randint(0, 9, 18000)
X_test = np.random.rand(2000, 4096, 2)
Y_test = np.random.randint(0, 9, 2000)

inputs = complex_input(shape=X_train.shape[1:])
outs = inputs
outs = ComplexConv1D(16, 6, strides=1, padding='same', activation='cart_relu')(outs)
outs = ComplexBatchNormalization()(outs)

outs = ComplexConv1D(32, 3, strides=1, padding='same', activation='cart_relu')(outs)
outs = ComplexBatchNormalization()(outs)

outs = ComplexFlatten()(outs)
DL_feature = ComplexDense(128, activation='cart_relu')(outs)
outs = ComplexDense(256, activation='cart_relu')(DL_feature)
outs = ComplexDense(256, activation='cart_relu')(outs)
predictions = ComplexDense(9, activation='cast_to_real')(outs)  # 9 classes: the labels above are 0-8

model = Model(inputs=inputs, outputs=predictions)
model.compile(optimizer=tensorflow.keras.optimizers.Adam(learning_rate=1e-4),
              loss=tensorflow.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

history = model.fit(X_train, Y_train, batch_size=32, epochs=3, verbose=1,
                    validation_data=(X_test, Y_test))
# (checkpoint, early-stopping and learning-rate callbacks omitted: they are not defined in this snippet)

One epoch takes almost 10 minutes to train. But when I substitute BatchNormalization() for ComplexBatchNormalization(), it only takes about half a minute. Any ideas?

NEGU93 commented 1 year ago

Indeed, the complex batch norm is not optimized, and it is not planned to be optimized in the short term. I am sorry for the trouble caused. The reason is similar to what happens with ComplexPyTorch.

jollyjonson commented 1 year ago

I was having the same problem and came up with this simple solution. According to the authors of ComplexPyTorch, performing batch normalization in a 'naive' way, i.e. separately on the real and imaginary parts, does not have a significant impact on results when compared to the complex formulation of Trabelsi et al.

Here's a TF version of their NaiveComplexBatchNorm layer, which can be used with the Keras functional API.

import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization


def naive_complex_batch_normalization(inputs: tf.Tensor) -> tf.Tensor:
    # Normalize the real and imaginary parts independently with ordinary
    # (real-valued) batch normalization, then recombine them.
    real = tf.cast(tf.math.real(inputs), tf.float32)
    imag = tf.cast(tf.math.imag(inputs), tf.float32)
    real_bn, imag_bn = BatchNormalization()(real), BatchNormalization()(imag)
    return tf.cast(tf.complex(real_bn, imag_bn), tf.complex64)
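
For example, it can be dropped straight into the functional model from the original post in place of ComplexBatchNormalization() (a minimal sketch reusing the names defined there):

outs = ComplexConv1D(16, 6, strides=1, padding='same', activation='cart_relu')(outs)
outs = naive_complex_batch_normalization(outs)  # instead of ComplexBatchNormalization()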

@NEGU93, would you be interested in a PR implementing this as a proper tf.keras.layers.Layer class?
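
Roughly, I would expect the layer to look something like this (an untested sketch; the NaiveComplexBatchNormalization name is only a placeholder):

import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization, Layer


class NaiveComplexBatchNormalization(Layer):
    # Batch-normalizes the real and imaginary parts with two independent
    # real-valued BatchNormalization sub-layers.
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.real_bn = BatchNormalization()
        self.imag_bn = BatchNormalization()

    def call(self, inputs, training=None):
        # Forward the training flag so the moving statistics update correctly.
        real = self.real_bn(tf.math.real(inputs), training=training)
        imag = self.imag_bn(tf.math.imag(inputs), training=training)
        return tf.complex(real, imag)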

NEGU93 commented 1 year ago

Sure. I am not sure what they base that claim on; from my point of view, a naive implementation may have a very negative impact on the phase, which is a crucial aspect of CVNN merits (Ref). But then, using CReLU should have a similar impact and it still works well, so... why not?
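
To illustrate the concern with a toy example (mine, not from the papers): normalizing the two parts separately is not a phase-preserving operation.

import numpy as np

# Toy illustration: standardizing real and imaginary parts separately
# changes the phase angle(z) of each sample.
z = np.array([3 + 1j, 1 + 3j], dtype=np.complex64)
real = (z.real - z.real.mean()) / z.real.std()
imag = (z.imag - z.imag.mean()) / z.imag.std()
z_naive = real + 1j * imag
print(np.angle(z))        # [0.32 1.25]
print(np.angle(z_naive))  # [-0.79  2.36] -- the phases are distorted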

Please submit your PR, and thank you for the contribution!

maorgranot1 commented 1 year ago

Here is an implementation of a small 1D CNN as an example, until that PR is integrated into the cvnn package:

import tensorflow as tf
from tensorflow import keras
import cvnn.layers as layers


def get_model(input_len=1000, activation_func='crelu'):
    inputs = layers.complex_input(shape=(input_len, 1))
    conv0 = layers.ComplexConv1D(64, 7, input_shape=(input_len, 1), activation=activation_func)(inputs)
    # Naive batch norm: normalize the real and imaginary parts separately.
    bn_r0 = keras.layers.BatchNormalization()(tf.cast(tf.math.real(conv0), tf.float32))
    bn_i0 = keras.layers.BatchNormalization()(tf.cast(tf.math.imag(conv0), tf.float32))
    p0 = layers.ComplexAvgPooling1D(pool_size=2)(tf.cast(tf.complex(bn_r0, bn_i0), tf.complex64))
    out = layers.ComplexConv1D(32, 3, activation=activation_func)(p0)
    return tf.keras.Model(inputs, out)
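
For instance, the model can then be built and inspected with (assuming the defaults above):

model = get_model(input_len=1000, activation_func='crelu')
model.summary()  # the two BatchNormalization sub-layers show up as ordinary real-valued layers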