keras-team / keras


How to save Scikit-Learn-Keras Model into a Persistence File (pickle/hd5/json/yaml) with Keras 3 #19650

Closed - CAW9 closed this issue 1 week ago

CAW9 commented 2 weeks ago

Hello. I have used the following solution for several years without issue, but upon upgrading to Keras 3, I am wondering what changes need to be made to fix this part of our code and migrate successfully:

The below issue presents the original solution:

https://github.com/keras-team/keras/issues/4274

import copy
import io

import h5py
import tensorflow as tf
from scikeras.wrappers import KerasClassifier as KCsci


class KerasClassifier(KCsci):

    def __getstate__(self):
        state = self.__dict__
        if "model_" in state:
            model = state["model_"]
            model_hdf5_bio = io.BytesIO()
            with h5py.File(model_hdf5_bio, 'w') as file:
                model.save(file)
            state["model_"] = model_hdf5_bio
            state_copy = copy.deepcopy(state)
            state["model_"] = model
            return state_copy
        else:
            return state

    def __setstate__(self, state):
        if "model_" in state:
            model_hdf5_bio = state["model_"]
            with h5py.File(model_hdf5_bio, 'r') as file:
                state["model_"] = tf.keras.models.load_model(file)
        self.__dict__ = state

Please note that I am trying to use python 3.12, scikeras 0.13.0, tensorflow 2.16.1, and keras 3.2.1.

A large part of the issue seems to be that model.save() now requires a .keras extension, which I have not been able to satisfy here because the file object is simply not a file path. I have also tried switching to model.export(), as the migration notes suggest, but can't get that quite right either. Finally, I tried calling model.save() on a temp file with the .keras extension and reading that temp file into the io.BytesIO() object; that raised no errors in __getstate__, but I was unable to unpack it in __setstate__.
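For concreteness, the temp-file round trip described above would look roughly like the sketch below (the helper names are mine and this is only an illustration of the idea, not a verified fix for the __setstate__ problem):

import os
import tempfile

import keras


def model_to_bytes(model):
    # Keras 3's model.save() expects a real path ending in ".keras",
    # so round-trip through a temporary file and keep the raw bytes.
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "model.keras")
        model.save(path)
        with open(path, "rb") as f:
            return f.read()


def model_from_bytes(blob):
    # Write the stored bytes back out to a ".keras" file and reload it.
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "model.keras")
        with open(path, "wb") as f:
            f.write(blob)
        return keras.models.load_model(path)

The idea would be for __getstate__ to stash the raw bytes (rather than an h5py file handle) in the state dict, and for __setstate__ to rebuild the model from them.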

Thank you for any help and wisdom that you can provide.

With gratitude, CAW9

fchollet commented 2 weeks ago

What are you trying to do here? Pickle a custom class that includes a Keras model?

CAW9 commented 2 weeks ago

Not exactly. I am trying to enable pickle dumps for a KerasClassifier wrapper. Previously, it was proposed in https://github.com/keras-team/keras/issues/4274 that this should be done with the following code:

class KerasClassifier(tf.keras.wrappers.scikit_learn.KerasClassifier):
    """TensorFlow Keras API neural network classifier.

    Workaround the tf.keras.wrappers.scikit_learn.KerasClassifier
    serialization issue using BytesIO and HDF5 in order to enable
    pickle dumps.

    Adapted from:
    https://github.com/keras-team/keras/issues/4274#issuecomment-519226139
    """

    def __getstate__(self):
        state = self.__dict__
        if "model" in state:
            model = state["model"]
            model_hdf5_bio = io.BytesIO()
            with h5py.File(model_hdf5_bio, 'w') as file:
                model.save(file)
                # tf.keras.models.save_model(model, file, save_format="h5")
            state["model"] = model_hdf5_bio
            state_copy = copy.deepcopy(state)
            state["model"] = model
            return state_copy
        else:
            return state

    def __setstate__(self, state):
        if "model" in state:
            model_hdf5_bio = state["model"]
            with h5py.File(model_hdf5_bio, 'r') as file:
                state["model"] = tf.keras.models.load_model(file)
        self.__dict__ = state

Since then, tf.keras.wrappers.scikit_learn.KerasClassifier has been deprecated and essentially replaced by scikeras.

The reason for the custom __getstate__ and __setstate__ is that these workarounds were historically needed (and recommended in the linked issue) to serialize the KerasClassifier wrapper object. Migrating to scikeras was not a problem, but now that Python 3.12, TensorFlow 2.16.1, and Keras 3.x are in use, the previously suggested mechanisms for serializing a KerasClassifier have broken down.

Thank you again!

fchollet commented 2 weeks ago

Did you try just using pickle with no workaround? It might well work with Keras 3.3.3.
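For reference, the no-workaround version would look roughly like the sketch below (toy model and made-up data; it leans on SciKeras' built-in pickle support instead of a custom __getstate__/__setstate__):

import pickle

import numpy as np
import keras
from keras import layers, models
from scikeras.wrappers import KerasClassifier


def build_model():
    # Toy two-class model with string activations.
    model = models.Sequential([
        keras.Input(shape=(4,)),
        layers.Dense(8, activation="relu"),
        layers.Dense(2, activation="softmax"),
    ])
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
    return model


X = np.random.rand(64, 4).astype("float32")
y = np.random.randint(0, 2, size=64)

clf = KerasClassifier(model=build_model, epochs=1, verbose=0)
clf.fit(X, y)

# SciKeras registers its own pickle support, so no custom
# __getstate__/__setstate__ is needed here.
restored = pickle.loads(pickle.dumps(clf))
print(restored.predict(X[:3]))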

CAW9 commented 2 weeks ago

The KerasClassifier object (with no workarounds attempted) does serialize just fine before fitting:

KerasClassifier(
    model=<function build_tf_estimator.<locals>.build_tf_model at 0x144b60220>
    build_fn=None
    warm_start=False
    random_state=None
    optimizer=rmsprop
    loss=None
    metrics=None
    batch_size=200
    validation_batch_size=None
    verbose=2
    callbacks=None
    validation_split=0.0
    shuffle=True
    run_eagerly=False
    epochs=30
    class_weight=None
)

p = dill.dumps(est)
print(dill.loads(p))


But after fitting I get a large cascade of errors:

est.fit(get_df_values(X), y, **kwargs)
p = dill.dumps(est)
print(dill.loads(p))

File "/Users/cgladue/downloads/py312/automl_model.py", line 202, in fit_on_all_data print(dill.loads(p)) ^^^^^^^^^^^^^ File "/Users/cgladue/downloads/py312/dill/_dill.py", line 303, in loads return load(file, ignore, kwds) ^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/cgladue/downloads/py312/dill/_dill.py", line 289, in load return Unpickler(file, ignore=ignore, kwds).load() ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/cgladue/downloads/py312/dill/_dill.py", line 444, in load obj = StockUnpickler.load(self) ^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/cgladue/downloads/py312/scikeras/_saving_utils.py", line 15, in unpack_keras_model return load_model(b, compile=True) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/cgladue/downloads/py312/keras/src/saving/saving_lib.py", line 141, in load_model return _load_model_from_fileobj( ^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/cgladue/downloads/py312/keras/src/saving/saving_lib.py", line 170, in _load_model_from_fileobj model = deserialize_keras_object( ^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/cgladue/downloads/py312/keras/src/saving/serialization_lib.py", line 720, in deserialize_keras_object raise TypeError( TypeError: <class 'keras.src.models.sequential.Sequential'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by get_config() are explicitly deserialized in the model's from_config() method.

Along with several other errors of the same form, like

Exception encountered: <class 'keras.src.layers.core.dense.Dense'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by get_config() are explicitly deserialized in the model's from_config() method.

I will look into this get_config and from_config suggestion, but I do not yet understand why fitting the model would make it unserializable.

SuryanarayanaY commented 2 weeks ago

Hi @CAW9 ,

If the model constructor has Keras layers, then you need to implement the get_config and from_config methods explicitly.
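For example, a minimal sketch of explicit get_config/from_config on a custom layer (the layer here is purely illustrative, not taken from this thread):

import keras
from keras import layers


@keras.saving.register_keras_serializable()
class ScaledDense(layers.Layer):
    """Toy layer wrapping an inner Dense layer."""

    def __init__(self, units, scale=1.0, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.scale = scale
        self.dense = layers.Dense(units)

    def call(self, inputs):
        return self.dense(inputs) * self.scale

    def get_config(self):
        # Only plain-Python values go into the config; nested Keras objects
        # would need keras.saving.serialize_keras_object /
        # deserialize_keras_object in get_config / from_config.
        config = super().get_config()
        config.update({"units": self.units, "scale": self.scale})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)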

CAW9 commented 2 weeks ago

I've reproduced the error using only keras, and no scikeras dependency ( https://colab.research.google.com/drive/1ps1jt8WMINt0mOqvnGkE2sNdTqHLNkW1?usp=sharing ). This time, it fails to serialize before fitting.

To clarify, you are suggesting that I cannot define and use a simple keras model like this:

def build_tf_model():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(
            10, input_dim=X_shape[1], activation=tf.nn.relu),
        tf.keras.layers.Dense(y_nunique, activation=tf.nn.softmax)
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(
            learning_rate=0.2, amsgrad=True, name='Adam'),
        loss='sparse_categorical_crossentropy',
    )
    return model

Do I instead need to implement a custom wrapper class for the Keras model and define custom get_config and from_config methods in that class?

CAW9 commented 2 weeks ago

If in that example you change keras from 3.3.3 to 2.15.0, the serialization does not fail.

SuryanarayanaY commented 2 weeks ago

I have reproduced the behaviour with tf-nightly (Keras 3). With TF 2.15 it works fine. Attached gist here.

SuryanarayanaY commented 2 weeks ago

The code is failing at the keras.activations.get(identifier) step. But when I pass the activation tf.nn.relu directly, it's not failing.

import keras
import tensorflow as tf

print(keras.__version__)

identifier = tf.nn.relu
config = keras.activations.get(identifier)
print(type(config))
callable(config)

# Output
3.3.3.dev2024050303
<class 'function'>
True
CAW9 commented 2 weeks ago

I was able to fix my code by changing:

activation=tf.nn.relu to activation="relu" and activation=tf.nn.softmax to activation="softmax"
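With that substitution, the build function from the earlier snippet becomes:

def build_tf_model():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(
            10, input_dim=X_shape[1], activation="relu"),
        tf.keras.layers.Dense(y_nunique, activation="softmax")
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(
            learning_rate=0.2, amsgrad=True, name='Adam'),
        loss='sparse_categorical_crossentropy',
    )
    return model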

This is a satisfactory workaround for me.

If you would like to close the issue, I support you. If you feel that there is still a bug you need to address, feel free to leave it open.

Thank you for your help on this!

SuryanarayanaY commented 2 weeks ago

With the identifier as a string, it's working. IMO this is still a bug when the identifier is either a dict or a function.

fchollet commented 2 weeks ago

Passing TF objects directly (e.g. tf.nn.softmax) does not play well with serialization. Make sure to pass Keras objects -- could be "softmax" or keras.ops.softmax.

When passing an external object, the object should be passed via the custom_objects dict at deserialization time (e.g. in load_model or deserialize_keras_object). However, because TF is nuts, the names of your objects aren't what you expect (e.g. tf.nn.softmax is named softmax_v2), so you have to take that into account (e.g. pass custom_objects={"softmax_v2": tf.nn.softmax}).
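A sketch of that second option, assuming a model that was saved to "model.keras" with tf.nn.softmax baked in:

import keras
import tensorflow as tf

# tf.nn.softmax's recorded name is "softmax_v2", so register it under that
# name when reloading.
restored = keras.models.load_model(
    "model.keras",
    custom_objects={"softmax_v2": tf.nn.softmax},
)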
