adriangb / scikeras

Scikit-Learn API wrapper for Keras.
https://www.adriangb.com/scikeras/
MIT License

Can't load saved model (UnicodeDecodeError) #309

Closed: g811201 closed this issue 5 months ago.

g811201 commented 10 months ago

I tried to use SciKeras's `KerasRegressor` to wrap my model:

```python
import warnings

from tensorflow import get_logger

get_logger().setLevel("ERROR")
warnings.filterwarnings("ignore", message="Setting the random state for TF")

import numpy as np
from scikeras.wrappers import KerasRegressor
from sklearn.datasets import make_regression
from tensorflow import keras

X_regr, y_regr = make_regression(1000, 20, n_informative=10, random_state=0)
X_regr.shape, y_regr.shape, y_regr.min(), y_regr.max()


def get_reg(meta, hidden_layer_sizes, dropout):
    # SciKeras passes fit-time metadata such as the number of input features.
    n_features_in_ = meta["n_features_in_"]
    model = keras.models.Sequential()
    model.add(keras.layers.Input(shape=(n_features_in_,)))
    for hidden_layer_size in hidden_layer_sizes:
        model.add(keras.layers.Dense(hidden_layer_size, activation="relu"))
        model.add(keras.layers.Dropout(dropout))
    model.add(keras.layers.Dense(1))
    return model


reg = KerasRegressor(
    model=get_reg,
    loss="mse",
    metrics=[KerasRegressor.r_squared],
    hidden_layer_sizes=(100,),
    dropout=0.5,
)
reg.fit(X_regr, y_regr)

y_pred = reg.predict(X_regr[:5])
y_pred

import pickle

bytes_model = pickle.dumps(reg)
new_reg = pickle.loads(bytes_model)  # fails on Windows, see traceback below
new_reg.predict(X_regr[:5])  # model is still trained
```
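The traceback further down shows `pickle.loads` calling into SciKeras's `unpack_keras_model`, which restores the weights with `model.load_weights(temp_dir)`, so the failing code runs while unpickling rather than while producing the bytes. The snippet below is a minimal pure-Python sketch of that general pattern (a custom `__reduce__` that packs a model to bytes and rebuilds it through a temporary directory on load). `ToyModel`, `Wrapper`, `pack_model`, `unpack_model` and `_rebuild` are hypothetical names; this is an illustration under those assumptions, not SciKeras's actual implementation.

```python
import os
import pickle
import tempfile


class ToyModel:
    """Hypothetical stand-in for a Keras model: just holds a list of weights."""

    def __init__(self, weights):
        self.weights = list(weights)

    def save_weights(self, path):
        # Write the weights as text, one value per line.
        with open(path, "w", encoding="utf-8") as f:
            f.write("\n".join(str(w) for w in self.weights))

    @classmethod
    def load_weights(cls, path):
        with open(path, encoding="utf-8") as f:
            return cls(float(line) for line in f.read().splitlines())


def pack_model(model):
    """Hypothetical helper: serialize the model to bytes via a temp directory."""
    with tempfile.TemporaryDirectory() as temp_dir:
        path = os.path.join(temp_dir, "weights.txt")
        model.save_weights(path)
        with open(path, "rb") as f:
            return f.read()


def unpack_model(packed):
    """Hypothetical helper: rebuild the model from bytes; runs inside pickle.loads."""
    with tempfile.TemporaryDirectory() as temp_dir:
        path = os.path.join(temp_dir, "weights.txt")
        with open(path, "wb") as f:
            f.write(packed)
        return ToyModel.load_weights(path)


class Wrapper:
    """Hypothetical estimator wrapper that pickles its model through pack/unpack."""

    def __init__(self, model):
        self.model = model

    def __reduce__(self):
        # pickle.dumps runs pack_model now; pickle.loads later calls
        # _rebuild(packed), which runs unpack_model on the loading machine.
        return (_rebuild, (pack_model(self.model),))


def _rebuild(packed):
    return Wrapper(unpack_model(packed))


original = Wrapper(ToyModel([0.5, -1.25, 3.0]))
restored = pickle.loads(pickle.dumps(original))
print(restored.model.weights)  # [0.5, -1.25, 3.0]
```

Because the rebuild step runs on whichever machine calls `pickle.loads`, anything environment-specific there (temp-directory location, path encoding) only comes into play at load time in this pattern.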

(This is the same as the example at https://adriangb.com/scikeras/stable/notebooks/Basic_Usage.html.)

I can train the model and save it with `pickle.dumps`, but loading it back with `pickle.loads` fails with:

```
Traceback (most recent call last):

  File D:\miniforge\envs\develop\Lib\site-packages\spyder_kernels\py3compat.py:356 in compat_exec
    exec(code, globals, locals)

  File d:\untitled3.py:52
    new_reg = pickle.loads(bytes_model)

  File D:\miniforge\envs\develop\Lib\site-packages\scikeras\_saving_utils.py:50 in unpack_keras_model
    model.load_weights(temp_dir)

  File D:\miniforge\envs\develop\Lib\site-packages\keras\src\utils\traceback_utils.py:70 in error_handler
    raise e.with_traceback(filtered_tb) from None

  File D:\miniforge\envs\develop\Lib\site-packages\tensorflow\python\training\py_checkpoint_reader.py:95 in NewCheckpointReader
    except RuntimeError as e:

UnicodeDecodeError: 'utf-8' codec can't decode byte 0xa6 in position 181: invalid start byte
```
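The final line of the traceback can be read directly: `0xa6` is `0b10100110`, and any byte whose top two bits are `10` is a UTF-8 continuation byte, so it can never begin a character, hence "invalid start byte". The snippet reproduces the same exception on a hypothetical byte string. The path and the `describe_utf8_failure` helper are illustrative assumptions; whether non-UTF-8 bytes in a path (for example text in the cp950/Big5 code page used by Traditional Chinese Windows) actually caused this particular failure is not established in this thread.

```python
def describe_utf8_failure(raw: bytes):
    """Return (reason, offending byte, position) if raw is not valid UTF-8, else None."""
    try:
        raw.decode("utf-8")
    except UnicodeDecodeError as exc:
        return exc.reason, raw[exc.start], exc.start
    return None


# Hypothetical bytes: an ASCII path prefix followed by the two-byte
# cp950 (Big5) sequence A6 B3.
raw = b"C:\\Users\\\xa6\xb3\\AppData\\Local\\Temp"

# 0xa6 >> 6 == 0b10: the bit pattern reserved for UTF-8 continuation bytes.
print(bin(0xA6), 0xA6 >> 6 == 0b10)
print(describe_utf8_failure(raw))  # ('invalid start byte', 166, 9)

# The same bytes decode without error under the cp950 code page.
print(raw.decode("cp950"))
```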

Environment:
- OS: Windows 10 22H2
- python==3.11.6
- tensorflow==2.14.0
- scikit-learn==1.13.1
- scikeras==0.12.0

P.S. Loading fails only on Windows; the same code loads the model successfully on Linux.

adriangb commented 5 months ago

I'm guessing this is fixed now. Please let me know if this is still not fixed with SciKeras 0.13.0. Thanks!