jmontalt opened this issue 2 years ago (status: open)
Triage notes: Not sure whether we have proper complex-value support across the Keras stack (we would probably need to rely on TF op support for complex values). We need to assess how much work this would take.
I think that, if we limit the discussion to the variance scaling initializers (and perhaps the standard random initializers), the amount of work needed is not overwhelming. This is because complex-valued samples can easily be generated using the existing real-valued keras.backend.RandomGenerator.
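As a rough illustration of why this is cheap, here is a minimal sketch in plain Python. It assumes `random.Random` as a stand-in for keras.backend.RandomGenerator, and `complex_normal` is a hypothetical helper, not an existing Keras API. A complex sample z = x + iy has E[|z|^2] = Var(x) + Var(y), so drawing x and y independently from N(0, variance / 2) with a real-valued generator yields the requested complex variance.

```python
import math
import random


def complex_normal(n, variance, rng):
    """Hypothetical helper: draw n complex samples z with E[z] = 0 and E[|z|^2] = variance.

    Only a real-valued generator (rng.gauss) is used: the real and imaginary
    parts are independent N(0, variance / 2) draws, so their variances add
    up to the requested complex variance.
    """
    stddev = math.sqrt(variance / 2.0)
    return [complex(rng.gauss(0.0, stddev), rng.gauss(0.0, stddev)) for _ in range(n)]


# Assumption: random.Random stands in for keras.backend.RandomGenerator.
rng = random.Random(0)
fan_in = 256
samples = complex_normal(10_000, variance=2.0 / fan_in, rng=rng)  # He-style scaling
```

The empirical mean of |z|^2 over `samples` should come out close to 2 / 256.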
If it helps, I have implemented them in an external package; you can find the code here. Note that much of that code is unchanged from the Keras package. There is also some wrapping logic, which obviously wouldn't be necessary if this were implemented directly in Keras.
Naturally, many other building blocks are needed to build complex-valued neural networks (e.g. dense, convolution, pooling and activation layers). Each of these components might need some work if Keras intended to support complex-valued networks out of the box. However, in many cases they can easily be expressed as a combination of real-valued operators without loss of generality or performance. This functionality could be built over time (I'd be happy to help!).
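For example, a complex dense (matrix-vector) product needs only real products, using (A + iB)(x + iy) = (Ax - By) + i(Ay + Bx). The sketch below is plain Python with hypothetical helper names, not Keras layer code; a real implementation would call backend matmuls instead of list comprehensions.

```python
def real_matvec(m, v):
    """Real matrix-vector product: the only primitive used below."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def complex_dense(w_re, w_im, x_re, x_im):
    """Hypothetical complex dense product y = W x built from four real matvecs.

    W = w_re + i * w_im and x = x_re + i * x_im are stored as separate real parts:
    (A + iB)(x + iy) = (Ax - By) + i(Ay + Bx).
    """
    ax = real_matvec(w_re, x_re)
    by = real_matvec(w_im, x_im)
    ay = real_matvec(w_re, x_im)
    bx = real_matvec(w_im, x_re)
    y_re = [p - q for p, q in zip(ax, by)]
    y_im = [p + q for p, q in zip(ay, bx)]
    return y_re, y_im


# W = [[1+2j, -1j], [3, 1+1j]] and x = [2-1j, 1+3j], split into real/imaginary parts.
w_re = [[1, 0], [3, 1]]
w_im = [[2, -1], [0, 1]]
x_re = [2, 1]
x_im = [-1, 3]
y_re, y_im = complex_dense(w_re, w_im, x_re, x_im)  # → ([7, 4], [2, 1]), i.e. [7+2j, 4+1j]
```

This uses four real products; Gauss's trick (computing (A + B)(x + y) once) can reduce it to three at the cost of extra additions.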
System information.
TensorFlow version (you are using): 2.8
Are you willing to contribute it (Yes/No): Yes
Describe the feature and the current behavior/state.
The VarianceScaling initializer and its subclasses do not currently support complex dtypes. Initialization of complex-valued networks is described in this paper, which also introduces complex-valued networks and some of their potential uses. The proposed feature would be a prerequisite for defining complex-valued layers such as Dense or Conv.
Will this change the current API? How?
No.
VarianceScaling and its subclasses will simply return complex values when passed a complex dtype, instead of raising an error.
Who will benefit from this feature?
Anyone who would like to use complex-valued networks, with applications including speech processing and medical imaging.
Contributing
Augment the VarianceScaling class to implement complex-valued normal/uniform distributions using the ideas in this paper. This would be similar to the recent implementation in JAX.
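A minimal sketch of what the augmented class could compute, in plain Python. Assumptions: `complex_variance_scaling` and `compute_fans` are hypothetical names; the fan computation and the max(1, n) guard follow Keras's real-valued VarianceScaling; "normal" is untruncated here (Keras defaults to a truncated normal); and "uniform" samples uniformly from a disk in the complex plane, which is one natural complex analogue of the real uniform distribution and not necessarily the choice made in the paper or in JAX.

```python
import cmath
import math
import random


def compute_fans(shape):
    """Keras-style fans: the last two dims are (in, out); leading dims form the receptive field."""
    if len(shape) == 1:
        return shape[0], shape[0]
    receptive_field = math.prod(shape[:-2])  # 1 for a 2-D kernel
    return shape[-2] * receptive_field, shape[-1] * receptive_field


def complex_variance_scaling(shape, scale=1.0, mode="fan_in", distribution="normal", seed=None):
    """Hypothetical complex analogue of VarianceScaling.

    Returns a flat list of complex samples z with E[z] = 0 and
    E[|z|^2] = scale / max(1, n), where n is chosen by `mode`.
    """
    fan_in, fan_out = compute_fans(shape)
    n = {"fan_in": fan_in, "fan_out": fan_out, "fan_avg": (fan_in + fan_out) / 2.0}[mode]
    variance = scale / max(1.0, n)
    rng = random.Random(seed)  # stand-in for the backend's real-valued generator
    size = math.prod(shape)
    if distribution == "normal":
        # Independent real and imaginary parts, each carrying half the variance.
        sd = math.sqrt(variance / 2.0)
        return [complex(rng.gauss(0.0, sd), rng.gauss(0.0, sd)) for _ in range(size)]
    if distribution == "uniform":
        # Uniform on a disk of radius R = sqrt(2 * variance): taking the radius as
        # R * sqrt(u) gives a uniform density over the disk and E[|z|^2] = R**2 / 2.
        radius = math.sqrt(2.0 * variance)
        return [
            radius * math.sqrt(rng.random()) * cmath.exp(2j * math.pi * rng.random())
            for _ in range(size)
        ]
    raise ValueError(f"Unsupported distribution: {distribution!r}")


# He-style complex init for a 3x3 conv with 16 input and 32 output channels (fan_in = 144).
kernel = complex_variance_scaling((3, 3, 16, 32), scale=2.0, distribution="uniform", seed=0)
```

Keeping E[|z|^2] equal to the real-valued target variance means the existing scale/mode semantics carry over unchanged; only the sampling step differs for complex dtypes.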