rgerkin opened this issue 3 years ago
Hi @rgerkin,
This is a critical issue, and it is largely constrained by how TensorFlow serializes and deserializes custom models and layers.
The short answer is that tf.keras does not currently support model.to_json() and tf.keras.models.model_from_json() for custom (subclassed) models. These functions only work for Sequential models or Functional API models. Additional details can be found here: https://www.tensorflow.org/guide/keras/save_and_serialize#custom_objects.
For custom models, there are two options depending on whether you want to save to disk.
1) Save to disk: Use PsiZ's custom functions model.save() and psiz.models.load_model().
2) Do not save to disk: Serialize all the components separately (i.e., architecture, weights, loss and optimizer).
Option 1 saves to disk, which may not be what you want. It relies on custom save and load methods; saving and loading would look something like the following:
import psiz
model = psiz.models.Rate(...)
filepath = 'path/to/model/directory'  # Note: the path is a directory, not a file.
model.save(filepath)
reconstructed_model = psiz.models.load_model(filepath)
With the release of TF 2.4, the default TF save and load methods should now be able to handle custom objects and layers. I'm currently working on a commit to take advantage of this. Adding a serialization decorator to the custom models (e.g., Rank and Rate) solves some problems for loading/saving models in TF 2.4, but I don't know if it will solve the to_json issue as well.
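As a rough illustration of what the serialization decorator does, here is a minimal sketch assuming TensorFlow 2.x. Offset and the package string "psiz_sketch" are hypothetical (not PsiZ code), and a toy layer is used for brevity; the decorator is applied to model classes such as Rank or Rate in the same way. Registration stores the class under a package-qualified name, so Keras can find it during deserialization without an explicit custom_objects mapping.

```python
import tensorflow as tf


# Hypothetical toy layer; the same decorator can be applied to a custom
# model class such as Rank or Rate.
@tf.keras.utils.register_keras_serializable(package="psiz_sketch")
class Offset(tf.keras.layers.Layer):
    def __init__(self, offset=0.0, **kwargs):
        super().__init__(**kwargs)
        self.offset = offset

    def call(self, inputs):
        return inputs + self.offset

    def get_config(self):
        # Base-layer config (name, dtype, ...) plus this layer's argument.
        config = super().get_config()
        config.update({"offset": self.offset})
        return config


# The registry key has the form "<package>><class name>".
print(tf.keras.utils.get_registered_name(Offset))  # psiz_sketch>Offset

# Because the class is registered, no custom_objects mapping is needed here.
restored = tf.keras.layers.deserialize(tf.keras.layers.serialize(Offset(offset=1.5)))
```

Whether model.to_json() writes this registered name or just the bare class name appears to depend on the TensorFlow version, so registration alone may not settle the to_json question.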
If you do not want to save to disk, you can serialize the architecture, weights, loss and optimizer separately. Assuming your Rate model is called model and model_class corresponds to a model class constructor (e.g., Rate):
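1) Architecture: serialize with model_config = model.get_config() and restore with model_class.from_config(model_config).
2) Weights: serialize with w = model.get_weights() and restore with model.set_weights(w).
3) Optimizer: serialize with optimizer_config = tf.keras.optimizers.serialize(model.optimizer) and restore with tf.keras.optimizers.deserialize(optimizer_config).
4) Loss: serialize with loss_config = tf.keras.losses.serialize(model.loss) and restore with tf.keras.losses.deserialize(loss_config).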
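The four steps above can be combined into an in-memory round trip. This is a minimal sketch assuming TensorFlow 2.x: TinyModel is a hypothetical subclassed model standing in for Rate, and the Adam optimizer, learning rate and MeanSquaredError loss are arbitrary choices. Note that the optimizer config captures hyperparameters only, not optimizer state such as momentum terms.

```python
import numpy as np
import tensorflow as tf


class TinyModel(tf.keras.Model):
    """Hypothetical stand-in for a custom model such as psiz.models.Rate."""

    def __init__(self, units=3, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.dense = tf.keras.layers.Dense(units)

    def call(self, inputs):
        return self.dense(inputs)

    def get_config(self):
        # Only the constructor arguments are needed to rebuild the model.
        return {"units": self.units}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


x = np.ones((2, 4), dtype="float32")

model = TinyModel(units=3)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
    loss=tf.keras.losses.MeanSquaredError(),
)
model(x)  # Call once so the weights are created.

# Serialize each component separately (nothing is written to disk).
model_class = TinyModel
model_config = model.get_config()
w = model.get_weights()
optimizer_config = tf.keras.optimizers.serialize(model.optimizer)
loss_config = tf.keras.losses.serialize(model.loss)

# Rebuild: architecture first, then weights, then optimizer and loss.
restored = model_class.from_config(model_config)
restored(x)  # Create the weights before calling set_weights.
restored.set_weights(w)
restored.compile(
    optimizer=tf.keras.optimizers.deserialize(optimizer_config),
    loss=tf.keras.losses.deserialize(loss_config),
)
```

The model has to be built (here by calling it once) before set_weights, because subclassed models create their variables lazily.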
I have done my best to implement the get_config/from_config API for all custom models and layers. I say "my best" because I do not have full test coverage yet.
All of this could be combined into custom to_json() and model_from_json() functions, similar to the custom functions I wrote for model.save() and psiz.models.load_model(). This is on my TODO list, but I'm not sure when I will get it implemented and fully tested. In the past, I have had problems using get_weights and set_weights with custom models, so it's critical that this functionality be tested thoroughly.
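One possible shape for such a combined to_json()/model_from_json() pair is sketched below. The helper names pack_model_json and unpack_model_json are hypothetical (they are not part of PsiZ or tf.keras). The sketch only bundles components already produced by get_config(), get_weights() and the optimizer/loss serialize() calls into one JSON string, converting NumPy weight arrays to nested lists on the way out and back to arrays (with their original dtype and shape) on the way in.

```python
import json

import numpy as np


def _to_builtin(obj):
    # json.dumps cannot handle NumPy scalars/arrays that may appear in configs.
    if isinstance(obj, (np.ndarray, np.generic)):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def pack_model_json(model_config, weights, optimizer_config, loss_config):
    """Hypothetical helper: bundle serialized components into one JSON string."""
    payload = {
        "model_config": model_config,
        # Store dtype and shape so arrays (including empty ones) round-trip exactly.
        "weights": [
            {"dtype": str(w.dtype), "shape": list(w.shape), "values": w.tolist()}
            for w in weights
        ],
        "optimizer_config": optimizer_config,
        "loss_config": loss_config,
    }
    return json.dumps(payload, default=_to_builtin)


def unpack_model_json(json_str):
    """Hypothetical helper: inverse of pack_model_json, returns four components."""
    payload = json.loads(json_str)
    weights = [
        np.asarray(w["values"], dtype=w["dtype"]).reshape(w["shape"])
        for w in payload["weights"]
    ]
    return (
        payload["model_config"],
        weights,
        payload["optimizer_config"],
        payload["loss_config"],
    )
```

Restoring would then follow the usual order: model_class.from_config(model_config), build the model, set_weights(weights), and compile with the deserialized optimizer and loss. Like the optimizer config itself, this bundle does not carry optimizer state.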
Hopefully this helps. Let me know if you need to_json/from_json functionality for your application.
@colemanliyah found another way: regenerate the model with the same code that created it in the first place, then load the weights. Will this work, or will it be missing something? In any case, it sounds like maybe we should just use the custom psiz model.save() and load_model() functions you described.
Yes, that approach works. However, it is worth noting that if you ever want to resume training smoothly, you will also need the state of the optimizer (e.g., momentum terms).
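If training needs to resume smoothly, one option (a sketch assuming TensorFlow 2.x, not something PsiZ provides) is to checkpoint the optimizer together with the model using tf.train.Checkpoint. The build_model() function and the toy Dense model are hypothetical stand-ins for "the same code that created it in the first place".

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    # Hypothetical: the same code that created the model originally.
    model = tf.keras.Sequential([tf.keras.Input(shape=(2,)), tf.keras.layers.Dense(1)])
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.1), loss="mse")
    return model


x = np.ones((8, 2), dtype="float32")
y = np.zeros((8, 1), dtype="float32")

model = build_model()
model.fit(x, y, epochs=2, verbose=0)  # One batch per epoch, so 2 optimizer steps.

# Track both the model and its optimizer, so optimizer variables are saved too.
ckpt_dir = tempfile.mkdtemp()
save_path = tf.train.Checkpoint(model=model, optimizer=model.optimizer).save(
    os.path.join(ckpt_dir, "ckpt")
)

# Regenerate the model from code, then restore weights and optimizer state.
restored = build_model()
tf.train.Checkpoint(model=restored, optimizer=restored.optimizer).restore(save_path)
```

Optimizer slot variables (e.g., Adam's moment estimates) that do not exist yet at restore time are filled in when the optimizer creates them on its next update.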
The dev branch has been updated to require TF 2.4 and use the TF 2.4 save/load methods (commit fa00fa8ba20073aaa2db8b5b2cf66a7f0602feaf). There is still an outstanding problem (see Issue #19), but the new approach successfully replaces PsiZ's custom save/load strategy. Note that the save formats are NOT the same.
import tensorflow as tf
# Save using the TF 2.4 tf.keras.Model.save method (my_model is an existing PsiZ model).
fp = 'model_directory'
my_model.save(fp, save_traces=False)
# Load the saved model.
loaded_model = tf.keras.models.load_model(fp)
Old models can still be loaded using psiz.models.load_model. The safest strategy is to load old models and resave them using the new format. The old load/save strategy is now deprecated.
@roads, when I follow your example to save the whole model, I get this error:
OSError: SavedModel file does not exist at: saving_whole_model/{saved_model.pbtxt|saved_model.pb}
where "saving_whole_model" is the name I saved the model under. Also, saving the model only works if I call model.save(fp) without the save_traces argument; when I add save_traces, I get an attribute error. My thinking is that we may somehow be using two different save() functions.
The last point to mention is that if I use psiz.models.load_model, I get a ValueError saying that there is an unknown layer "BehaviorLog". So far, I have only been able to get save and load to work by regenerating the model and loading the weights, as Dr. Gerkin mentioned above.
Hi @colemanliyah,
Which version of psiz are you using? The new save/load strategy only applies if you are using the dev branch, but it will be available soon in release 0.4.2.
Regarding the BehaviorLog error, I think it is tied up with the first issue. Having said that, the general pattern for loading models with custom layers is:
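import tensorflow as tf
# custom_objects must be a dict mapping serialized names to classes (BehaviorLog comes from psiz).
custom_objects = {'BehaviorLog': BehaviorLog}
loaded_model = tf.keras.models.load_model(fp, custom_objects=custom_objects)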
I've updated examples/rate/mle_1g.py on the dev branch with code that illustrates the save/load process.
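To see what custom_objects does without touching the disk, the sketch below (assuming TensorFlow 2.x) serializes a small, unregistered custom layer and deserializes it again. Scale is a hypothetical layer standing in for BehaviorLog, and tf.keras.layers.serialize/deserialize are used in place of load_model so the example is self-contained. Without the name-to-class mapping, Keras cannot locate the class and raises an "unknown layer" style error; the exact exception type and wording vary across TensorFlow versions.

```python
import tensorflow as tf


class Scale(tf.keras.layers.Layer):
    """Hypothetical custom layer (stands in for BehaviorLog); not registered."""

    def __init__(self, factor=1.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs):
        return inputs * self.factor

    def get_config(self):
        # Base-layer config (name, dtype, ...) plus this layer's argument.
        config = super().get_config()
        config.update({"factor": self.factor})
        return config


layer_config = tf.keras.layers.serialize(Scale(factor=2.0, name="scale"))

# custom_objects maps the serialized class name to the Python class.
restored_layer = tf.keras.layers.deserialize(
    layer_config, custom_objects={"Scale": Scale}
)
```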
@roads, oh I see. I was using an old copy of psiz, not the updated dev branch. Thank you for your help with this!
No problem. Please don't hesitate to ask if you run into any other issues. The package is going through some growing pains, so there are many opportunities for confusion. I'm working on some proper documentation, but it still has a ways to go.
From the examples/rate directory, if I do:
I get
ValueError: Unknown layer: Rate
from the Keras deserialization routine. I would think this is because the psiz.models.rate.Rate model is not registered, so I added the @tf.keras.utils.register_keras_serializable(package='psiz.models.rate', name='Rate') decorator to that model definition, but the problem did not go away. Looking at the JSON itself, the class_name is just "Rate", whereas other classes have a more informative class_name attribute that seems to indicate a registered class. Until this is solved, I am not able to deserialize models (though for now I can work around this by simply reconstructing them in Python code).