NEGU93 / cvnn

Library to help implement a complex-valued neural network (cvnn) using TensorFlow as the backend
https://complex-valued-neural-networks.readthedocs.io/
MIT License
164 stars, 34 forks

Model subclassing compatibility #26

Open lminer opened 2 years ago

lminer commented 2 years ago

I've been trying to get this to work with the model subclassing API, but for some reason, the first layer of the model always expects the data to be in float32. Any idea how to get this to work?

NEGU93 commented 2 years ago

Did you add the `ComplexInput` layer at the start? If not, TF will automatically use the TensorFlow `Input` and cast the data to float.
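To illustrate why an implicit cast to float is a problem, here is a TensorFlow-only check (it does not use cvnn). Per the `tf.cast` documentation, casting a complex tensor to a real dtype keeps only the real part. Whether Keras performs exactly this cast in the reporter's setup is an assumption; the snippet only shows what a complex-to-float cast does to the data.

```python
import numpy as np
import tensorflow as tf

# Two complex-valued features for a single sample.
z = tf.constant(np.array([[1.0 + 2.0j, -3.0 + 0.5j]], dtype=np.complex64))

# Casting complex -> float keeps only the real part; the imaginary part is
# discarded (TensorFlow may log a warning, but it does not raise an error).
x = tf.cast(z, tf.float32)

print(x.numpy())  # real parts only: the imaginary parts 2.0 and 0.5 are gone
```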

lminer commented 2 years ago

How would you do that in a subclassed model? I have custom train and test steps, so I can't just do:

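```python
import tensorflow as tf
from cvnn.layers import ComplexInput

inputs = ComplexInput((1, 2, 3))           # complex64 input of shape (1, 2, 3)
outputs = SubClassedModel()(inputs)        # instantiate the model, then call it
model = tf.keras.Model(inputs, outputs)
```

NEGU93 commented 2 years ago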

It is possible... Did it work? I have never worked with subclassed models, so I don't know how they work. What is the difference from a normal model?

lminer commented 2 years ago

I couldn't get it to work; using the approach above would have required too big a refactor. Subclassed models are just a PyTorch-like interface: you inherit from the normal model (`tf.keras.Model`), build the layers in the constructor, and implement the forward pass in `call`. It's the same as a custom layer, but with a model instead.
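A minimal sketch of the subclassing pattern described above, under stated assumptions: it uses plain TensorFlow 2 (`tf.keras`, the Keras 2 API that cvnn targets) rather than cvnn's own layers, so it runs without cvnn. `ToyComplexDense`, `ToyComplexModel` and `train_step` are hypothetical names for illustration only. The idea it shows is that a subclassed model has no `Input` layer: you feed `complex64` tensors straight into `model(x)`, and as long as every layer does complex arithmetic, nothing forces the data to `float32`. In cvnn you would presumably swap the toy layer for the library's complex layers; that substitution has not been checked here.

```python
import numpy as np
import tensorflow as tf


class ToyComplexDense(tf.keras.layers.Layer):
    """Hypothetical complex dense layer: y = x @ (W_re + i*W_im) + (b_re + i*b_im).

    Weights are stored as real float32 variables, so Keras and the optimizer only
    ever handle real tensors; the complex arithmetic happens inside call().
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        in_dim = int(input_shape[-1])
        self.w_re = self.add_weight(name="w_re", shape=(in_dim, self.units), initializer="glorot_uniform")
        self.w_im = self.add_weight(name="w_im", shape=(in_dim, self.units), initializer="glorot_uniform")
        self.b_re = self.add_weight(name="b_re", shape=(self.units,), initializer="zeros")
        self.b_im = self.add_weight(name="b_im", shape=(self.units,), initializer="zeros")

    def call(self, inputs):
        # Complex tensors are not floating-point dtypes, so Keras autocasting leaves them alone.
        w = tf.complex(self.w_re, self.w_im)
        b = tf.complex(self.b_re, self.b_im)
        return tf.matmul(inputs, w) + b


class ToyComplexModel(tf.keras.Model):
    """Hypothetical subclassed model: layers built in __init__, forward pass in call()."""

    def __init__(self, hidden=8, **kwargs):
        super().__init__(**kwargs)
        self.dense1 = ToyComplexDense(hidden)
        self.dense2 = ToyComplexDense(1)

    def call(self, inputs):
        h = self.dense1(inputs)
        # Split ("cartesian") ReLU: apply ReLU to the real and imaginary parts separately.
        h = tf.complex(tf.nn.relu(tf.math.real(h)), tf.nn.relu(tf.math.imag(h)))
        return self.dense2(h)


model = ToyComplexModel()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)


def train_step(x, y):
    """Hand-written training step (a custom loop, not a Keras train_step override)."""
    with tf.GradientTape() as tape:
        y_pred = model(x)
        # The loss must be real-valued: mean squared magnitude of the complex error.
        loss = tf.reduce_mean(tf.abs(y - y_pred) ** 2)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss


# Toy complex data: the target is the sum of the four complex input features.
rng = np.random.default_rng(0)
x = tf.constant((rng.normal(size=(32, 4)) + 1j * rng.normal(size=(32, 4))).astype(np.complex64))
y = tf.reduce_sum(x, axis=-1, keepdims=True)

for step in range(5):
    print(f"step {step}: loss={float(train_step(x, y)):.4f}")
```

Because the model is first called on a complex tensor, its layers are built from that input. Building it from a shape alone (e.g. via `model.build(...)`) may create real-valued placeholders instead; that is a plausible, unverified explanation for the float32 behaviour reported in this issue.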