ZafranShah closed this issue 3 years ago.
Indeed, this has not been implemented or tested yet. Because my applications are research-oriented, I usually only need the results and discard the model afterwards. It is on my TODO list, but I don't know when I'll have the time. Sorry!
If you or anyone else wants to fix this, feel free to open a PR.
It is now working. I don't know what changed; I decided to look into this again, tried it, and it worked. This is what I tried, based on this tutorial:
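import tensorflow as tf

checkpoint_path = "training_1/cp.ckpt"
# get_dataset_for_segmentation() and get_model() are the user's own helper functions
train_dataset, test_dataset = get_dataset_for_segmentation()
model = get_model()
# Save only the weights to checkpoint_path at the end of every epoch
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, save_weights_only=True, verbose=1)
model.fit(x=train_dataset, epochs=2, validation_data=test_dataset, shuffle=True, callbacks=[cp_callback])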
After that, I ran the following code:
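checkpoint_path = "training_1/cp.ckpt"
train_dataset, test_dataset = get_dataset_for_segmentation()
# Build a fresh model with the same architecture (randomly initialized)
model = get_model()
# evaluate() returns [loss, accuracy], assuming get_model() compiles with metrics=["accuracy"]
loss, acc = model.evaluate(test_dataset, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100 * acc))
# Restore the weights saved by the ModelCheckpoint callback
model.load_weights(checkpoint_path)
loss, acc = model.evaluate(test_dataset, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))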
This produced the following output:
7/7 - 5s - loss: 1.0357 - accuracy: 0.3303
Untrained model, accuracy: 33.03%
7/7 - 4s - loss: 0.8419 - accuracy: 0.6200
Restored model, accuracy: 62.00%
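For reference, here is a self-contained sketch of the same save-weights / restore-weights round trip. It is not the code from this thread: `build_model()` is a hypothetical tiny model and random NumPy arrays stand in for `get_model()` and `get_dataset_for_segmentation()`. It assumes TF 2.x or Keras 3; the checkpoint name ends in `.weights.h5` because Keras 3 requires that suffix when `save_weights_only=True` (the tutorial's `cp.ckpt` path is fine on TF 2.x).

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    # Hypothetical stand-in for the user's get_model()
    inputs = tf.keras.Input(shape=(4,))
    hidden = tf.keras.layers.Dense(8, activation="relu")(inputs)
    outputs = tf.keras.layers.Dense(2, activation="softmax")(hidden)
    model = tf.keras.Model(inputs, outputs)
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model


def roundtrip_weights(tmp_dir):
    # Random data stands in for get_dataset_for_segmentation()
    x = np.random.rand(16, 4).astype("float32")
    y = np.random.randint(0, 2, size=(16,))
    checkpoint_path = os.path.join(tmp_dir, "cp.weights.h5")

    model = build_model()
    # Saves the weights at the end of each epoch; the last save holds the final weights
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_path, save_weights_only=True, verbose=0)
    model.fit(x, y, epochs=2, verbose=0, callbacks=[cp_callback])

    # A freshly built model starts from new random weights ...
    restored = build_model()
    # ... until the checkpointed weights are loaded into it
    restored.load_weights(checkpoint_path)
    return model.predict(x, verbose=0), restored.predict(x, verbose=0)
```

Calling `roundtrip_weights(tempfile.mkdtemp())` should give identical predictions from the trained and restored models. This only restores weights, so the architecture must be rebuilt in code first; that is why it sidesteps the custom-layer deserialization problem entirely.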
The problem is resolved: you can load the trained model by passing the custom layers to load_model through the custom_objects argument, for example:
model = keras.models.load_model(pathoftrainedmodel, custom_objects={'ComplexInput': cvnn.layers.complex_input, ...})
I listed all the custom layers in custom_objects, and now it works.
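The same pattern can be shown end to end with a minimal sketch. `Scale` is a hypothetical custom layer standing in for cvnn's layers (it is not part of cvnn); the sketch assumes TF 2.x or Keras 3 with h5py available for the `.h5` file. The important detail is that each key in `custom_objects` must match the class name stored in the saved file.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class Scale(tf.keras.layers.Layer):
    """Hypothetical custom layer that multiplies its input by a constant."""

    def __init__(self, factor=2.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs):
        return inputs * self.factor

    def get_config(self):
        # Store the constructor arguments so the layer can be rebuilt on load
        config = super().get_config()
        config.update({"factor": self.factor})
        return config


def save_and_reload(path):
    inputs = tf.keras.Input(shape=(3,))
    outputs = Scale(factor=3.0)(inputs)
    model = tf.keras.Model(inputs, outputs)
    model.save(path)
    # The key "Scale" must match the class name recorded in the saved config,
    # just as 'ComplexInput' must appear as a key when loading the cvnn model
    return tf.keras.models.load_model(path, custom_objects={"Scale": Scale})
```

For example, `save_and_reload(os.path.join(tempfile.mkdtemp(), "model.h5"))` returns a model whose last layer is a `Scale` instance with `factor == 3.0`. Implementing `get_config()` matters as much as `custom_objects`: without it, the constructor arguments are not saved and cannot be passed back when the layer is rebuilt.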
Hi,
I can train a model consisting of complex convolutional and deconvolutional layers; however, when I load the trained complex model, I get the following error. Any idea how to get rid of it?
...
  File "/home/shah/site-packages/tensorflow/python/keras/utils/generic_utils.py", line 321, in class_and_config_for_serialized_keras_object
    raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
ValueError: Unknown layer: ComplexInput
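Why this error appears can be illustrated with a deliberately simplified model of name-based deserialization. This is not Keras's actual code: the classes and `resolve_layer()` are hypothetical placeholders. A saved model stores each layer's class name as a string, and loading only succeeds if that name can be mapped back to a class, either from the built-in layers or from what the caller passes in.

```python
class Dense:
    """Placeholder for a layer class the framework already knows about."""


class ComplexInput:
    """Placeholder for a custom layer defined outside the framework."""


# Simplified model of the built-in name -> class registry
BUILTIN_LAYERS = {"Dense": Dense}


def resolve_layer(class_name, custom_objects=None):
    """Map the class name stored in a saved config back to a class."""
    registry = dict(BUILTIN_LAYERS)
    # Entries supplied by the caller are consulted alongside the built-ins
    registry.update(custom_objects or {})
    if class_name not in registry:
        # Same message format as the ValueError in the traceback above
        raise ValueError("Unknown layer: " + class_name)
    return registry[class_name]
```

In this sketch, `resolve_layer("ComplexInput")` raises `ValueError: Unknown layer: ComplexInput`, while `resolve_layer("ComplexInput", {"ComplexInput": ComplexInput})` succeeds, which mirrors why passing `custom_objects` to `load_model` fixes the real error.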