keras-team / tf-keras

The TensorFlow-specific implementation of the Keras API, which was the default Keras from 2019 to 2023.
Apache License 2.0

Complex-valued variance scaling initializers #111

Open jmontalt opened 2 years ago

jmontalt commented 2 years ago

System information.

TensorFlow version (you are using): 2.8
Are you willing to contribute it (Yes/No): Yes

Describe the feature and the current behavior/state.

The VarianceScaling initializer and its subclasses do not currently support complex dtypes. Initialization of complex-valued networks was described in this paper, which also contains an introduction to them and some of their potential uses.

The proposed feature would be a prerequisite to defining complex-valued layers such as Dense or Conv.

Will this change the current api? How?

No. VarianceScaling and its subclasses will simply return complex values when passed a complex dtype, instead of raising an error.

Who will benefit from this feature?

Anyone who would like to use complex-valued networks, with applications including speech processing and medical imaging.

Contributing

Augment the VarianceScaling class to implement the complex-valued normal/uniform distributions using the ideas in this paper. This would be similar to the recent implementation in JAX.
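For reference, the real-valued VarianceScaling rule targets a weight variance of scale / n. Below is a rough sketch of one way to carry that target over to complex weights. It assumes that the variance of a zero-mean complex weight is defined as E|w|^2. For the normal case, the target is split evenly between the real and imaginary parts. For the uniform case, weights are sampled uniformly from a disk whose radius is chosen so the variance still matches. This is an illustration only and is not necessarily identical to the construction in the paper or in JAX. The truncated-normal case is omitted.

```latex
\begin{aligned}
\sigma^2 &= \frac{\mathrm{scale}}{n}, \quad n \in \left\{\mathrm{fan\_in},\ \mathrm{fan\_out},\ \tfrac{1}{2}(\mathrm{fan\_in} + \mathrm{fan\_out})\right\} \\
\operatorname{Var}(w) &= \mathbb{E}\,|w|^2 = \operatorname{Var}(\operatorname{Re} w) + \operatorname{Var}(\operatorname{Im} w) \\
\text{normal:}\quad & \operatorname{Re} w,\ \operatorname{Im} w \overset{\text{iid}}{\sim} \mathcal{N}\!\left(0, \tfrac{\sigma^2}{2}\right) \;\Rightarrow\; \mathbb{E}\,|w|^2 = \sigma^2 \\
\text{uniform:}\quad & w = \sigma\sqrt{2U}\, e^{i\theta},\ U \sim \mathcal{U}(0,1),\ \theta \sim \mathcal{U}(0, 2\pi) \;\Rightarrow\; \mathbb{E}\,|w|^2 = 2\sigma^2\,\mathbb{E}[U] = \sigma^2
\end{aligned}
```

The disk form follows because |w| = σ√(2U) gives P(|w| ≤ t) = t² / (2σ²), which is proportional to the enclosed area. The samples are therefore uniform over a disk of radius √2·σ.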

qlzh727 commented 2 years ago

Triage notes: Not sure if we have proper complex-valued support across the Keras stack (we would probably need to rely on TF ops supporting complex values). We need to assess how much work this would require.

jmontalt commented 2 years ago

I think that, if we limit the discussion to the variance scaling initializers (and perhaps the standard random initializers), the amount of work required is not overwhelming. This is because complex random number generation can easily be expressed using the existing real-valued keras.backend.RandomGenerator.

If it helps, I have implemented them in an external package, and you can find the code here. Note that much of that code is unchanged from the Keras package. There is also some wrapping logic that would not be necessary if this were implemented directly in Keras.

Naturally, many other building blocks are necessary to build complex-valued neural networks (e.g. dense, conv, pooling, and activation layers). Each of these components might need some work if Keras intended to support complex-valued networks out of the box. However, in many cases they can easily be expressed as a combination of real-valued operators without loss of generality or performance. This functionality could be built over time (I'd be happy to help!).
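As one illustration of expressing a complex building block with real-valued operators, here is a NumPy sketch of a complex dense layer (affine part only, no activation). It uses the identity (A + iB)(C + iD) = (AC - BD) + i(AD + BC). The function names are hypothetical. The second function packs the same product into a single real matmul of doubled width. The sketch only demonstrates numerical equivalence and makes no claim about performance.

```python
import numpy as np


def complex_dense_real_ops(x_re, x_im, w_re, w_im, b_re=None, b_im=None):
    """Complex affine map x @ W + b computed with real-valued matmuls only.

    (x_re + i x_im) @ (w_re + i w_im)
        = (x_re @ w_re - x_im @ w_im) + i (x_re @ w_im + x_im @ w_re)
    """
    out_re = x_re @ w_re - x_im @ w_im
    out_im = x_re @ w_im + x_im @ w_re
    if b_re is not None:
        out_re = out_re + b_re
    if b_im is not None:
        out_im = out_im + b_im
    return out_re, out_im


def complex_dense_block(x_re, x_im, w_re, w_im):
    """Same product as one real matmul on stacked inputs and a 2x2 block kernel.

    [x_re, x_im] @ [[w_re, w_im], [-w_im, w_re]] = [out_re, out_im]
    """
    x = np.concatenate([x_re, x_im], axis=-1)
    w = np.block([[w_re, w_im], [-w_im, w_re]])
    out = x @ w
    units = w_re.shape[-1]
    return out[..., :units], out[..., units:]
```

The block form shows that a complex kernel of shape (in, units) behaves like a structured real kernel of shape (2·in, 2·units). This structure is one reason such layers can be assembled from existing real-valued ops.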