cshimmin closed this issue 3 years ago
So this question was never answered?
Use save_weights to save the model. When loading, build and compile the model first, then call load_weights. You might also be able to compile the model after loading the weights; I didn't try that.
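A minimal sketch of that workaround, assuming TensorFlow 2.x with its bundled Keras. The `build_model` helper, the layer sizes, and the `toy.weights.h5` file name are made up for illustration and do not come from this thread:

```python
import os
import tempfile

import numpy as np
from tensorflow import keras


def build_model():
    """Hypothetical helper: rebuilds the exact same architecture every time."""
    model = keras.Sequential([
        keras.Input(shape=(4,)),
        keras.layers.Dense(8, activation="relu"),
        keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")
    return model


# Save only the weights of the original model.
original = build_model()
weights_path = os.path.join(tempfile.mkdtemp(), "toy.weights.h5")
original.save_weights(weights_path)

# Later: build and compile a fresh model first, then load the weights into it.
restored = build_model()
restored.load_weights(weights_path)

# Every weight array of the restored model should equal the original one.
weights_match = all(
    np.array_equal(a, b)
    for a, b in zip(original.get_weights(), restored.get_weights())
)
```

Note that this only round-trips the layer weights. The architecture has to be recreated in code, and the optimizer state is not restored.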
@xing-w, the issue I describe is not that save_weights does not work as expected. The issue is specifically regarding saving the model. Of course I can just re-instantiate the exact same keras model and load in the weights. But for this to be practically useful one would need to also export a bunch of metadata about how to re-build the model to a separate file. Not to mention that saving/loading weights does not solve the issue of saving the optimizer state. All of this is exactly what Keras' Model.save() is supposed to do, except that it does not work as expected, as reported here.
I ended up writing a custom Callback in a VAE setup to save the encoder and decoder separately. This may be relevant to your setup too.
Is this a solution? https://stackoverflow.com/questions/48373845/loading-model-with-custom-loss-keras
So is there no answer to this question? Is it a bug?
The custom_objects kwarg to model saving and loading worked great for me.
Would you expand on how you used the custom_objects kwarg please, QCaudron?
See the docs here (under "Handling custom layers (or other custom objects) in saved models"): https://keras.io/getting-started/faq/
Here's some of my code that works well.
# Imports needed by this snippet
from keras import backend as K
from keras.models import load_model

# Loss function : the Jaccard index
def jaccard_coef(y_true, y_pred):
    """
    The negated Jaccard index, for use as a loss: -1 is the perfect overlap.
    """
    y_true_flat = K.flatten(y_true)
    y_pred_flat = K.flatten(y_pred)
    intersection = K.sum(y_true_flat * y_pred_flat)
    union = K.sum(y_true_flat) + K.sum(y_pred_flat) - intersection
    # The small constant avoids division by zero when both masks are empty
    return -intersection / (union + 1e-6)

# Loading the model
model = load_model(
    "unet_profond.h5",
    custom_objects={"jaccard_coef": jaccard_coef}
)
Your function must be defined in scope at load time, so you can refer to it. Saving is the same.
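To see what values this loss takes, here is a NumPy re-implementation of the same arithmetic. It is only an illustrative sketch: the name `jaccard_loss_np` and the toy mask are made up, and NumPy stands in for the Keras backend ops.

```python
import numpy as np


def jaccard_loss_np(y_true, y_pred, eps=1e-6):
    """NumPy version of the negated Jaccard index used as a loss above."""
    y_true_flat = np.ravel(y_true).astype(float)
    y_pred_flat = np.ravel(y_pred).astype(float)
    intersection = np.sum(y_true_flat * y_pred_flat)
    union = np.sum(y_true_flat) + np.sum(y_pred_flat) - intersection
    return -intersection / (union + eps)


# Toy binary mask (hypothetical data)
mask = np.array([[1, 1], [0, 0]])

# Prediction identical to the target: loss is very close to -1
perfect = jaccard_loss_np(mask, mask)

# Prediction with no overlap at all: loss is 0
disjoint = jaccard_loss_np(mask, 1 - mask)
```

Because the index is negated, minimizing this loss maximizes the overlap between prediction and target.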
For people who encounter this issue with VAEs (specifically when loading the vae model in the Keras example VAE code): please see my answer here for a possible solution and example implementation.
@ihalage This is not a fix for the issue. @cshimmin already mentioned this approach, and it is not practical.
@GPla First, this is not a fix for this particular issue, it is a workaround to save and load a model, specifically in VAEs (which is mentioned in my comment). Second, I have practically implemented and tested it.
I am trying to save models which have custom loss functions that are added to the model using Model.add_loss(). This is NOT the same issue which has already been seen several times, where you have to pass custom_objects=... to load_model(); in fact, when using add_loss, I do not include any loss function when calling Model.compile().
Here is a brief script that can reproduce the issue:
Everything in this script works as expected, except for the last line. Instead of loading the model, I get an error:
Of course, it should be noted that this model is a toy example and is not doing something super interesting (it will just learn to output zeros). One actual application where I regularly have this problem is with variational autoencoders, where the reconstruction and KL losses are added in this manner, and there is no explicit comparison of y_pred and y_test, since the loss in such models is also completely defined by the input, as well as depending on other things such as internal layer outputs for the encoder means and variances.
Is there some way with Keras to save/load these kinds of models, where the loss is defined explicitly as some kind of tensor expression?
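To make the VAE case concrete, here is a NumPy sketch of such a loss. It is built only from the input, its reconstruction, and the encoder's means and log-variances, so no separate target is involved. The function name, the squared-error reconstruction term, and the toy arrays are assumptions for illustration, not the code from this issue; the KL term is the standard closed form for a diagonal Gaussian against a unit Gaussian prior.

```python
import numpy as np


def vae_loss_np(x, x_reconstructed, z_mean, z_log_var):
    """Illustrative VAE loss: reconstruction error plus KL divergence."""
    # Reconstruction term: squared error summed over features
    # (an assumption; binary cross-entropy is another common choice).
    reconstruction = np.sum((x - x_reconstructed) ** 2, axis=-1)
    # Closed-form KL( N(mean, exp(log_var)) || N(0, I) ), summed over latent dims
    kl = -0.5 * np.sum(
        1.0 + z_log_var - z_mean ** 2 - np.exp(z_log_var), axis=-1
    )
    # Average over the batch
    return np.mean(reconstruction + kl)


# Toy batch of two samples with a 3-dimensional latent space (hypothetical data)
x = np.array([[0.0, 1.0], [1.0, 0.0]])

# Perfect reconstruction and a posterior equal to the prior: loss is 0
loss_perfect = vae_loss_np(x, x, np.zeros((2, 3)), np.zeros((2, 3)))

# Same reconstruction, but encoder means shifted to 1: only the KL term grows
loss_shifted = vae_loss_np(x, x, np.ones((2, 3)), np.zeros((2, 3)))
```

Nothing here has the `loss(y_true, y_pred)` signature that `Model.compile()` expects, which is why such a loss is attached with `Model.add_loss()` instead.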
[x] Check that you are up-to-date with the master branch of Keras. You can update with:
pip install git+git://github.com/keras-team/keras.git --upgrade --no-deps
[x] Check that your version of TensorFlow is up-to-date. The installation instructions can be found here.
[x] Provide a link to a GitHub Gist of a Python script that can reproduce your issue (or just copy the script here if it is short).