bryanlimy / tf2-transformer-chatbot

Transformer Chatbot in TensorFlow 2 with TPU support.
MIT License

Saving model?? #16

Closed Rob-Byrne closed 2 years ago

Rob-Byrne commented 3 years ago

I have been banging my head trying to save the model for later evaluation. Do you have code that will do this (i.e. generate the config for the model, or embed it, so that model.save() works)?

bryanlimy commented 3 years ago

Hi @Rob-Byrne, you can call save_weights and load_weights:

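filename = 'weights.h5'
model.save_weights(filename)  # saves the layer weights to an HDF5 file

# Rebuild the same architecture, then load the saved weights into it.
new_model = transformer(hparams)
new_model.load_weights(filename)
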
frankplus commented 3 years ago

model.save_weights(filename) does not save the optimizer state, so it's not possible to restore the model and resume training after it has stopped. model.save() should save the optimizer state, but I cannot get it to work; it fails with "Not JSON Serializable".
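
One way to keep the optimizer state without going through model.save() is tf.train.Checkpoint, which tracks the model and the optimizer together. The following is a minimal sketch, not the repository's code: the small Sequential model is a stand-in for transformer(hparams), and the random data, learning rate, and temporary directory are assumptions chosen only to make the example self-contained.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Stand-in for transformer(hparams); any Keras model is handled the same way.
model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)
model.compile(optimizer=optimizer, loss="mse")

# Random toy data, only here to drive a couple of training steps.
rng = np.random.default_rng(0)
x = rng.random((4, 3)).astype("float32")
y = rng.random((4, 2)).astype("float32")

# Train one step so that the optimizer has state worth saving.
model.train_on_batch(x, y)
ref_pred = model(x, training=False).numpy()

# Track the model and the optimizer in a single checkpoint.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "ckpt")
save_path = checkpoint.save(ckpt_prefix)

# Keep training: the weights and the optimizer state move on.
model.train_on_batch(x, y)
drifted_pred = model(x, training=False).numpy()

# Restore the saved state; training can continue from this point.
checkpoint.restore(save_path)
restored_pred = model(x, training=False).numpy()
```

In a fresh process the model and optimizer would first be rebuilt with the same hyperparameters, and then passed to a new tf.train.Checkpoint before calling restore.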

bryanlimy commented 3 years ago

Yes @frankplus, with model subclassing we have to override the get_config and from_config methods; see https://www.tensorflow.org/guide/keras/save_and_serialize#custom_objects.
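
As a rough illustration of what overriding get_config involves, here is a minimal sketch. ScaledDense is a hypothetical layer invented for this example and is not part of the repository. The idea is that get_config returns every constructor argument as a plain Python value, so the config can be written as JSON and passed back to the constructor.

```python
import json

import tensorflow as tf


class ScaledDense(tf.keras.layers.Layer):
    """Hypothetical custom layer, used only to illustrate get_config."""

    def __init__(self, units, scale=1.0, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.scale = scale
        self.dense = tf.keras.layers.Dense(units)

    def call(self, inputs):
        return self.dense(inputs) * self.scale

    def get_config(self):
        # Start from the base config (name, dtype, ...) and add each
        # constructor argument as a plain, JSON-friendly Python value.
        config = super().get_config()
        config.update({"units": self.units, "scale": self.scale})
        return config


layer = ScaledDense(units=8, scale=0.5)
config = layer.get_config()

# json.dumps raises TypeError if any value in the config is not serializable.
serialized = json.dumps(config)

# The default from_config calls the constructor with the config as kwargs.
restored = ScaledDense.from_config(config)
```

If a constructor argument is itself an object (a tensor, another layer, a hyperparameter container), it has to be converted to plain values in get_config and rebuilt in from_config.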

This was already done in an earlier PR, https://github.com/bryanlimy/tf2-transformer-chatbot/pull/13, though only in the Jupyter notebook. You can take a look; it should be relatively simple to port the changes over.

aqibfayyaz commented 3 years ago

> model.save_weights(filename) does not save the optimizer state, so it's not possible to restore the model and resume training after it has stopped. model.save() should save the optimizer state, but I cannot get it to work; it fails with "Not JSON Serializable".

Hi, were you successful?

aqibfayyaz commented 3 years ago

> Hi @Rob-Byrne, you can call save_weights and load_weights:
>
> filename = 'weights.h5'
> model.save_weights(filename)
>
> new_model = transformer(hparams)
> new_model.load_weights(filename)

Could you please upload the code that saves the model too?

bryanlimy commented 2 years ago

https://github.com/bryanlimy/tf2-transformer-chatbot/pull/23