keras-team / keras

[Feature request] Complex dtype input support for layers #19860

Open refraction-ray opened 3 months ago

refraction-ray commented 3 months ago

See the original issue in tensorflow's repo: https://github.com/tensorflow/tensorflow/issues/65306

Current behavior?

The following code works in TF 2.15 but fails in TF 2.16.1, which introduced Keras 3:

import tensorflow as tf
print(tf.__version__)

class MyDenseLayer(tf.keras.layers.Layer):
    def __init__(self, num_outputs):
        super(MyDenseLayer, self).__init__()
        self.num_outputs = num_outputs

    def build(self, input_shape):
        self.kernel = self.add_weight(shape=[int(input_shape[-1]),
                                         self.num_outputs], )

    def call(self, inputs):
        kernel = tf.cast(self.kernel, tf.complex64)
        return tf.matmul(inputs, kernel)

layer = MyDenseLayer(10)
layer(tf.zeros([10, 5], dtype=tf.complex64))

In short, TF 2.16.1 with the newly introduced Keras 3 does not seem to support complex-dtype input tensors for layers.

Relevant log output

usr/local/lib/python3.10/dist-packages/keras/src/backend/common/variables.py in standardize_dtype(dtype)
    428 
    429     if dtype not in dtypes.ALLOWED_DTYPES:
--> 430         raise ValueError(f"Invalid dtype: {dtype}")
    431     return dtype
    432 

ValueError: Invalid dtype: complex64
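
The traceback suggests the failure comes from a plain allowlist check on the dtype name. A minimal sketch of that kind of check, for illustration only: the function name `standardize_dtype_sketch`, the helper `try_dtype`, and the contents of `ALLOWED_DTYPES` here are assumptions, not Keras's actual code or its actual list.

```python
# Illustrative allowlist. This is an assumption made for the sketch,
# not the real contents of Keras's dtypes.ALLOWED_DTYPES.
ALLOWED_DTYPES = {"float16", "float32", "float64", "int32", "int64", "bool"}


def standardize_dtype_sketch(dtype):
    """Hypothetical stand-in mirroring the check shown in the traceback."""
    if dtype not in ALLOWED_DTYPES:
        raise ValueError(f"Invalid dtype: {dtype}")
    return dtype


def try_dtype(dtype):
    """Return the accepted dtype name, or the error message if it is rejected."""
    try:
        return standardize_dtype_sketch(dtype)
    except ValueError as exc:
        return str(exc)


print(try_dtype("float32"))
print(try_dtype("complex64"))
```

Under this reading, any dtype name missing from the allowlist is rejected before the layer's `call` runs, which would explain why casting the kernel inside `call` does not help.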

It would be better if Keras 3 supported complex-valued input for layers, as Keras did before. Complex-valued input is very common in quantum machine learning use cases.

mehtamansi29 commented 3 months ago

Hi @refraction-ray -

We have raised an internal fix for this issue. Soon it will be reflected in the documentation as well. Thanks.