NEGU93 / cvnn

Library to help implement a complex-valued neural network (CVNN) using TensorFlow as back-end
https://complex-valued-neural-networks.readthedocs.io/
MIT License

ComplexConv2D with bias vector slows down training a lot #34

Open osannolik opened 1 year ago

osannolik commented 1 year ago

If use_bias is set to True for ComplexConv2D (which is the default), training takes 2-3 times longer than with no bias. With Keras' real-valued Conv2D, the difference due to the bias is basically none.

Is that to be expected?

I observed that CPU usage is a bit higher when bias is enabled, but the allocated GPU memory is the same in both cases.

See below for a simple example.

import numpy as np
import tensorflow as tf
import cvnn.layers

n_samples = 10000
data_shape = (n_samples, 128, 256, 2)
input_shape = data_shape[1:]

# Random complex64 inputs and real-valued regression targets
data = np.csingle(np.random.rand(*data_shape) + 1j*np.random.rand(*data_shape))
labels = np.float32(np.random.rand(n_samples))

use_bias = True  # True increases train time by a factor of 2-3

# Single complex conv layer followed by a dense layer that maps to a real output
model = tf.keras.models.Sequential([
    cvnn.layers.ComplexInput(input_shape=input_shape, dtype=np.complex64),
    cvnn.layers.ComplexConv2D(8, (5, 5), activation='cart_relu', use_bias=use_bias),
    cvnn.layers.ComplexFlatten(),
    cvnn.layers.ComplexDense(1, activation='convert_to_real_with_abs')
])

print("Total size: {} MB".format((data.nbytes + labels.nbytes) / 1_000_000))

model.compile(optimizer=tf.optimizers.Adam(learning_rate=1e-04),
              loss='mean_squared_error',
              metrics=[tf.keras.metrics.RootMeanSquaredError()])

model.summary()

model.fit(data, labels, epochs=5, verbose=2)

NEGU93 commented 1 year ago

Hello @osannolik, and thank you for the bug report.

This is an interesting insight; I was not expecting this to happen. Here is how the bias is added. Indeed, it seems like too much code for something so simple. I could also optimize conv2D using the Gauss trick to do 3 real convolutions instead of 4 (see the sketch below).
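
For reference, a minimal sketch of that Gauss (3-multiplication) idea applied to a 2D convolution, written directly against tf.nn.conv2d rather than the cvnn layers; complex_conv2d_gauss, w_real and w_imag are made-up names for illustration, not part of the library:

import tensorflow as tf

def complex_conv2d_gauss(x, w_real, w_imag, strides=1, padding="VALID"):
    # x: complex64 input (N, H, W, C); w_real, w_imag: real kernels (kh, kw, C, filters).
    # A naive complex convolution needs 4 real convolutions:
    #   real = conv(xr, wr) - conv(xi, wi),  imag = conv(xr, wi) + conv(xi, wr)
    # The Gauss trick trades them for 3 convolutions plus a few cheap kernel adds:
    #   k1 = conv(xr + xi, wr), k2 = conv(xr, wi - wr), k3 = conv(xi, wr + wi)
    #   real = k1 - k3,         imag = k1 + k2
    xr, xi = tf.math.real(x), tf.math.imag(x)
    k1 = tf.nn.conv2d(xr + xi, w_real, strides, padding)
    k2 = tf.nn.conv2d(xr, w_imag - w_real, strides, padding)
    k3 = tf.nn.conv2d(xi, w_real + w_imag, strides, padding)
    # A complex bias could then be added with a single broadcasted add,
    # e.g. result + tf.complex(b_real, b_imag)
    return tf.complex(k1 - k3, k1 + k2)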

Due to the holidays, I won't sit down with this right now. I might fix it in January; if you need it sooner, I invite you to open a pull request.

osannolik commented 1 year ago

It would be nice to have it fixed, as it is a pretty good optimization that would benefit more or less all users of ComplexConv2D. Unfortunately, I'm not sure I will have the time to fix it myself in the near future.

Lunariz commented 1 year ago

Any updates on this? I recently started using ComplexConv2D and saw my training time slow down by a factor of 3~4.

I haven't tried enabling or disabling the bias specifically, but I can give that a try to see if that is the cause.

NEGU93 commented 1 year ago

Are you using batch norm? That is the main speed issue with my library.

Lunariz commented 1 year ago

No, I'm not using any batch norm or layer norm at all!

NEGU93 commented 1 year ago

Unfortunately, I now work full time as an engineer and no longer have time to keep developing/upgrading this library. I am sorry for the users affected, but I don't think this fix is coming in the near future.

NEGU93 commented 1 year ago

I've been told that dividing by just the norm of the complex value would be good enough, and it is quite fast to compute. I will leave it noted as an option in this issue.
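
For anyone who wants to try that suggestion, here is a minimal sketch of a magnitude-only normalization, i.e. dividing each complex activation by its norm (with a small epsilon for stability). This is only an illustration of the idea above, not a feature of the library:

import tensorflow as tf

def normalize_by_magnitude(z, epsilon=1e-7):
    # Divide each complex activation by its magnitude, which keeps the phase
    # and scales every element to (roughly) unit norm. Unlike a complex batch
    # normalization, it needs no batch statistics, so it is cheap to compute.
    magnitude = tf.cast(tf.abs(z) + epsilon, z.dtype)
    return z / magnitude

It could be dropped into a model with something like tf.keras.layers.Lambda(normalize_by_magnitude), though I have not benchmarked whether it matches the accuracy of a full complex batch normalization.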