NEGU93 / cvnn

Library to help implement a complex-valued neural network (cvnn) using TensorFlow as the back end
https://complex-valued-neural-networks.readthedocs.io/
MIT License

Load a CVNN model successfully #27

Closed: Aminehlel closed this issue 2 years ago

Aminehlel commented 2 years ago

This is how I save and load a CVNN model. Here is an example:

1) Saving the model:

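# Imports (assumed; the original post did not show them)
import tensorflow as tf
from tensorflow import keras
import cvnn.activations
import cvnn.initializers
import cvnn.layers as complex_layers
from cvnn.losses import ComplexAverageCrossEntropy
from cvnn.metrics import ComplexCategoricalAccuracy

# input_shape, X_train and y_train come from your own dataset and must be defined beforehand

# Build the model
tf.random.set_seed(1)
init = cvnn.initializers.ComplexGlorotUniform()
acti = 'cart_relu'
model = tf.keras.models.Sequential()
model.add(complex_layers.ComplexInput(input_shape=input_shape))
model.add(complex_layers.ComplexConv2D(32, (3, 3), padding='same', activation=acti, kernel_initializer=init))
model.add(complex_layers.ComplexMaxPooling2D((2, 2)))
model.add(complex_layers.ComplexConv2D(64, (3, 3), padding='same', activation=acti, kernel_initializer=init))
model.add(complex_layers.ComplexFlatten())
model.add(complex_layers.ComplexDense(64, activation=acti, kernel_initializer=init))
model.add(complex_layers.ComplexDense(10, activation='convert_to_real_with_abs', kernel_initializer=init))
model.summary()  # summary() prints the table itself and returns None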

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss=ComplexAverageCrossEntropy(),
              metrics=[ComplexCategoricalAccuracy()])  # Keras documents `metrics` as a list
# Train the model
history = model.fit(X_train, y_train, epochs=10, validation_split=0.2, batch_size=32)

# Save the model in HDF5 format
keras.models.save_model(model, "./models/model.hdf5")

2) Loading the model:


# To load the model, pass a custom_objects dict containing every complex object used to build it
# For this example, the load looks like this:

model = keras.models.load_model(
    "./models/model.hdf5",
    custom_objects={
        # 'Custom>Adam' is the name the saved file uses for the optimizer; the matching
        # class depends on your TensorFlow/Keras version, adjust it if loading fails
        'Custom>Adam': tf.keras.optimizers.experimental.Adam,
        'convert_to_real_with_abs': cvnn.activations.convert_to_real_with_abs,
        'ComplexInput': complex_layers.ComplexInput,
        'ComplexConv2D': complex_layers.ComplexConv2D,
        'ComplexMaxPooling2D': complex_layers.ComplexMaxPooling2D,
        'ComplexFlatten': complex_layers.ComplexFlatten,
        'ComplexDense': complex_layers.ComplexDense,
        'ComplexAverageCrossEntropy': cvnn.losses.ComplexAverageCrossEntropy,
        'ComplexCategoricalAccuracy': cvnn.metrics.ComplexCategoricalAccuracy,
    },
)
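To avoid typos in the `custom_objects` keys, and to check that the reload really worked, two small helpers may be handy. This is a minimal sketch, not part of cvnn: `custom_objects_from` and `predictions_match` are hypothetical names. The first assumes each key equals the object's Python `__name__` (which holds for the cvnn classes and functions listed in the dictionary above); entries whose saved name differs, such as `'Custom>Adam'`, have to be passed through `extra`. The second only relies on both models having a `.predict` method.

```python
import numpy as np


def custom_objects_from(*objects, extra=None):
    """Build a Keras `custom_objects` dict keyed by each object's __name__.

    Hypothetical helper: assumes the name stored in the saved file equals the
    Python __name__ of the class or function. Names that differ (for example
    'Custom>Adam') must be supplied explicitly through `extra`.
    """
    mapping = {obj.__name__: obj for obj in objects}
    if extra:
        mapping.update(extra)
    return mapping


def predictions_match(original, reloaded, x, atol=1e-6):
    """Return True if both models produce numerically equal outputs on `x`.

    Hypothetical helper: works with anything exposing a `.predict(x)` method,
    and compares shapes first, then values within the absolute tolerance `atol`.
    """
    before = np.asarray(original.predict(x))
    after = np.asarray(reloaded.predict(x))
    return bool(before.shape == after.shape and np.allclose(before, after, atol=atol))
```

For the model above, you could pass `custom_objects_from(complex_layers.ComplexInput, complex_layers.ComplexConv2D, complex_layers.ComplexMaxPooling2D, complex_layers.ComplexFlatten, complex_layers.ComplexDense, cvnn.activations.convert_to_real_with_abs, cvnn.losses.ComplexAverageCrossEntropy, cvnn.metrics.ComplexCategoricalAccuracy, extra={'Custom>Adam': adam_cls})` to `load_model`, where `adam_cls` is the same optimizer class used in the dictionary above. If you keep the trained model under another name before reloading (for example `trained = model`), `predictions_match(trained, model, X_train[:32])` should return `True` when the weights were restored correctly.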

I hope this helps!

NEGU93 commented 2 years ago

Thank you for your contribution!