NEGU93 / cvnn

Library to help implement a complex-valued neural network (cvnn) using TensorFlow as back-end
https://complex-valued-neural-networks.readthedocs.io/
MIT License
160 stars · 33 forks

Casting Error complex -> float during learning only. #52

Open rjdw opened 2 months ago

rjdw commented 2 months ago

I am using the example from the documentation (Documentation Example).

After initialization, the model correctly takes complex input and outputs a complex value.

However, during training I get the warning: You are casting an input of type complex64 to an incompatible dtype float32. This will discard the imaginary part and may not be what you intended.

During training, the model is unable to output a complex value; the output contains the real part only.

I am unable to determine the origin of this casting error.
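
For reference, the warning string itself appears to come straight from TensorFlow's tf.cast; this standalone snippet (separate from the cvnn example) triggers the identical message:

import tensorflow as tf

z = tf.constant([1.0 + 2.0j], dtype=tf.complex64)
x = tf.cast(z, tf.float32)  # logs the same "incompatible dtype" warning
print(x)                    # keeps only the real part: [1.]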

As per the TF 2.16 incompatibility, I am using:

tensorflow                   2.15.0
cvnn                         1.2.22

MWE:

import numpy as np  # needed below for np.complex64 dtypes
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from keras.initializers import HeNormal, GlorotUniform
import cvnn.layers as complex_layers
import cvnn.initializers as complex_initializers
## Custom Loss Function
def custom_loss(k):
    def loss(y_true, y_pred):
        # Extract real (a) and imaginary (b) parts
        a = tf.math.real(y_pred)
        b = tf.math.imag(y_pred)

        # Main loss: (y - a)^2 + b^2
        mse = tf.reduce_mean(tf.square(y_true - a) + tf.square(b))

        # Regularization term: (e^(-k * a) - 1) if a < 0, else 0
        condition = tf.less(a, 0)
        regularization_term = tf.where(condition, tf.exp(-k * a) - 1, tf.zeros_like(a))
        regularization_term = tf.reduce_mean(regularization_term)

        # Total loss
        total_loss = mse + regularization_term
        return total_loss
    return loss
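
# (Illustrative sanity check, not part of the original example: with made-up
# values, the loss reduces a complex64 prediction against a float32 target
# and returns a float32 scalar, with no cast applied to y_pred itself.)
y_true_demo = tf.constant([1.0, -0.5], dtype=tf.float32)
y_pred_demo = tf.constant([0.8 + 0.1j, -0.4 - 0.2j], dtype=tf.complex64)
print(custom_loss(0.001)(y_true_demo, y_pred_demo))  # tf.Tensor(..., dtype=float32)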

# Custom MAE metric
# abs(a - y_true)
def custom_mae(y_true, y_pred):
    a = tf.math.real(y_pred)
    mae = tf.reduce_mean(tf.abs(y_true - a))
    return mae

# Printing b_value metric
def b_value(y_true, y_pred):
    b = tf.math.imag(y_pred)
    b_val = tf.reduce_mean(tf.abs(b))
    return b_val
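
# (Same made-up values as above: both metrics take the real/imaginary part
# explicitly, so each returns a float32 scalar with no implicit complex cast.)
print(custom_mae(y_true_demo, y_pred_demo))  # mean |y_true - Re(y_pred)|
print(b_value(y_true_demo, y_pred_demo))     # mean |Im(y_pred)|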
# Example from Docs
num_features = (8, 148)  # input shape, matching the simulated data below
test_model = Sequential()
test_model.add(complex_layers.ComplexInput(input_shape=num_features, dtype=np.complex64))
test_model.add(complex_layers.ComplexFlatten())
test_model.add(complex_layers.ComplexDense(32, activation='cart_relu', dtype=np.complex64))
test_model.add(complex_layers.ComplexDense(1, dtype=np.complex64))
print(test_model.output_shape)
# Showing that the model correctly takes complex input and outputs complex value
x = tf.cast(tf.random.normal((1, 8, 148)), tf.complex64)
x.shape, x.dtype

(TensorShape([1, 8, 148]), tf.complex64)

out = test_model(x)
out.dtype, out

(tf.complex64, <tf.Tensor: shape=(1, 1), dtype=complex64, numpy=array([[1.0031033+0.26049298j]], dtype=complex64)>)

k = 0.001

# Compile the model
test_model.compile(optimizer='adam', loss=custom_loss(k), metrics=[custom_mae, b_value])
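
# (Debugging suggestion, not in the original: compiling with run_eagerly=True
# executes the train step in plain Python, so the cast warning comes with a
# usable stack trace pointing at the line that performs the cast.)
# test_model.compile(optimizer='adam', loss=custom_loss(k),
#                    metrics=[custom_mae, b_value], run_eagerly=True)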
# Sim Data
import numpy as np

# Generate 100 complex-valued 2D arrays of shape (8, 148)
data = np.random.randn(100, 8, 148) + 1j * np.random.randn(100, 8, 148)
data = data.astype(np.complex64)

# Generate 100 labels of dtype float32
labels = np.random.randn(100).astype(np.float32)

# Split data and labels into training (80%) and validation (20%) sets
train_size = int(0.8 * len(data))

t_d = data[:train_size]
t_l = labels[:train_size]

v_d = data[train_size:]
v_l = labels[train_size:]

t_d.shape, t_l.shape, v_d.shape, v_l.shape
# Data is correctly typed
for sample in t_d:
    if sample.dtype != 'complex64':
        print(sample.dtype)
        break
t_l[0].dtype, t_d[0].real.dtype

(dtype('float32'), dtype('float32'))

from keras.callbacks import Callback

# Callback for seeing the b value in the complex output (a + ib)
class PrintBValueCallback(Callback):
    def on_train_batch_end(self, batch, logs=None):
        b_value = logs.get('b_value')
        print(f'Batch {batch}, b_value: {b_value}')
# Train the model
history = test_model.fit(t_d, t_l, 
                    epochs=32, batch_size=16, 
                    shuffle=True, 
                    callbacks=[PrintBValueCallback()],
                    validation_data=(v_d, v_l))
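
# (Probe suggestion, not in the original: push one batch through the layers by
# hand and print each output dtype; the first layer that prints float32 would
# be where the imaginary part is lost. Outside fit(), this stays complex64.)
h = tf.cast(t_d[:2], tf.complex64)
for layer in test_model.layers:
    h = layer(h)
    print(layer.name, h.dtype)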

[screenshot: complex_error_cast]

On the first training batch, the phase value is already 0: no valid complex value is being output. A float is being cast back into complex, so the complex output has no imaginary part.

Edit: I changed the dtype from complex128 to complex64; the result is the same casting error.

NEGU93 commented 2 months ago

I think the default dtype is complex64 and not complex128. Can you try using all complex64 and see if it works?
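
For example, a sketch with every layer pinned to complex64 (the kernel_initializer shown is just an illustration of using the library's complex initializers, not a required change):

test_model = Sequential()
test_model.add(complex_layers.ComplexInput(input_shape=(8, 148), dtype=np.complex64))
test_model.add(complex_layers.ComplexFlatten())
test_model.add(complex_layers.ComplexDense(
    32, activation='cart_relu', dtype=np.complex64,
    kernel_initializer=complex_initializers.ComplexGlorotUniform()))
test_model.add(complex_layers.ComplexDense(1, dtype=np.complex64))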

rjdw commented 2 months ago

Hi, sorry for the late response.

I tried using all complex64, but the same casting error occurs. I've edited the above MWE to reflect this.