tensorflow / neural-structured-learning

Training neural models with structured signals.
https://www.tensorflow.org/neural_structured_learning
Apache License 2.0

Saving `AdversarialRegularization` models is currently not supported. Consider using `save_weights` or saving the `base_model` #113

wangbingnan136 closed this issue 2 years ago

wangbingnan136 commented 2 years ago

How can I save an `AdversarialRegularization` model?

wangbingnan136 commented 2 years ago

I used the `ModelCheckpoint` callback, and it raised the error "Saving `AdversarialRegularization` models is currently not supported. Consider using `save_weights` or saving the `base_model`". I found that the `save` method of `AdversarialRegularization` is not implemented.

csferng commented 2 years ago

Hi @wangbingnan136, thanks for your question.

Yes, `AdversarialRegularization.save()` is not implemented. (PRs are more than welcome!) Depending on your use case, here are the suggested alternatives:

  1. For checkpointing and pausing/resuming training, you may use `AdversarialRegularization.save_weights()` to save a checkpoint and `AdversarialRegularization.load_weights()` to load a checkpoint back.
  2. For exporting a trained model for evaluation and serving, you may use `AdversarialRegularization.base_model.save()` to export it in TensorFlow SavedModel or HDF5 format. The exported base model can be loaded back using `tf.keras.models.load_model()`, and you may construct a new `AdversarialRegularization` object from the loaded model. (The `AdversarialRegularization` class doesn't have any model variables of its own, so the new object should behave the same as the original one.)
wangbingnan136 commented 2 years ago

Thanks for the reply! I have already opened a pull request to fix the `save` bug. It now saves only the base model of `adv_model`.

csferng commented 2 years ago

Closing this issue since PR #114 has been merged.