adriangb / scikeras

Scikit-Learn API wrapper for Keras.
MIT License

Can't pickle trained model with callback to TensorBoard #236

Closed joooeey closed 3 years ago

joooeey commented 3 years ago

Description of the problem

I was excited about SciKeras because it interfaces with sklearn and its models can supposedly be pickled. Unfortunately, scikeras.KerasClassifier can't be pickled when both of the following conditions are met:

  1. a TensorBoard callback is passed to the constructor, and
  2. the classifier has been fitted.

The equivalent neural network in pure Keras can be pickled without issue.

Minimal, Complete, Verifiable Example

from joblib import dump
# from pickle import dump  # causes the same problem
from numpy import random

from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
from scikeras.wrappers import KerasClassifier

# %% shared data

X = random.random((10, 6))
y = random.randint(2, size=10)

def build_fn():
    """Build sequential neural network."""
    model = Sequential()
    model.add(Dense(30, activation="relu", input_shape=(6, )))
    model.add(Dense(20, activation="relu"))
    model.add(Dense(1, activation="sigmoid"))
    return model

# %% scikeras classifier [breaks]

clf = KerasClassifier(
    model=build_fn,
    loss="binary_crossentropy",
    callbacks=[TensorBoard("testlogs")],  # won't break without this line
)
clf =, y)  # won't break without this line

dump(clf, open("test_scikeras.pkl", "wb"))  # raises InvalidArgumentError

# %% same classifier in pure tf.keras [works]

model = build_fn()
model.compile(loss="binary_crossentropy")
, y, callbacks=[TensorBoard("testlogs")])

dump(model, open("test_keras.pkl", "wb"))  # works

Stack Trace

The last line of the # %% scikeras classifier [breaks] block raises:

Traceback (most recent call last):

  File "/home/lukas/Desktop/", line 52, in <module>
    dump(clf, open("test_scikeras.pkl", "wb"))  # raises InvalidArgumentError

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/", line 482, in dump
    NumpyPickler(filename, protocol=protocol).dump(value)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/", line 487, in dump

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/", line 282, in save
    return, obj)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/", line 603, in save
    self.save_reduce(obj=obj, *rv)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/", line 717, in save_reduce

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/", line 282, in save
    return, obj)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/", line 560, in save
    f(self, obj)  # Call unbound method with explicit self

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/", line 971, in save_dict

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/", line 997, in _batch_setitems

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/", line 282, in save
    return, obj)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/", line 560, in save
    f(self, obj)  # Call unbound method with explicit self

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/", line 931, in save_list

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/", line 958, in _batch_appends

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/", line 282, in save
    return, obj)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/", line 603, in save
    self.save_reduce(obj=obj, *rv)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/", line 717, in save_reduce

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/", line 282, in save
    return, obj)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/", line 560, in save
    f(self, obj)  # Call unbound method with explicit self

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/", line 971, in save_dict

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/", line 997, in _batch_setitems

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/", line 282, in save
    return, obj)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/", line 560, in save
    f(self, obj)  # Call unbound method with explicit self

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/", line 971, in save_dict

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/", line 997, in _batch_setitems

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/", line 282, in save
    return, obj)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/", line 603, in save
    self.save_reduce(obj=obj, *rv)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/", line 717, in save_reduce

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/", line 282, in save
    return, obj)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/", line 560, in save
    f(self, obj)  # Call unbound method with explicit self

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/", line 971, in save_dict

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/", line 997, in _batch_setitems

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/", line 282, in save
    return, obj)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/", line 578, in save
    rv = reduce(self.proto)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/tensorflow/python/framework/", line 1000, in __reduce__
    return convert_to_tensor, (self._numpy(),)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/tensorflow/python/framework/", line 1039, in _numpy
    six.raise_from(core._status_to_exception(e.code, e.message), None)  # pylint: disable=protected-access

  File "<string>", line 3, in raise_from

InvalidArgumentError: Cannot convert a Tensor of dtype resource to a NumPy array.
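
A quick way to narrow down which part of an object breaks pickling is to try pickling its attributes one at a time. The helper below is a generic, hypothetical sketch (`find_unpicklable` is not part of SciKeras, TensorFlow or joblib), demonstrated on stand-in classes in which a `threading.Lock` plays the role of the unpicklable TensorFlow resource.

```python
import pickle
import threading


def find_unpicklable(obj, path="obj", depth=0, max_depth=3):
    """Return the attribute paths that make `obj` fail to pickle.

    Recurses into instance __dict__ entries, dicts, lists and tuples, and
    reports the deepest paths (up to max_depth) that still fail on their own.
    """
    try:
        pickle.dumps(obj)
        return []  # this value pickles fine, nothing to report
    except Exception:
        pass
    if depth >= max_depth:
        return [path]
    if isinstance(obj, dict):
        children = [(f"{path}[{k!r}]", v) for k, v in obj.items()]
    elif isinstance(obj, (list, tuple)):
        children = [(f"{path}[{i}]", v) for i, v in enumerate(obj)]
    elif hasattr(obj, "__dict__"):
        children = [(f"{path}.{k}", v) for k, v in vars(obj).items()]
    else:
        children = []
    culprits = []
    for child_path, child in children:
        culprits.extend(find_unpicklable(child, child_path, depth + 1, max_depth))
    # If no child fails on its own, blame the object itself.
    return culprits or [path]


# Hypothetical stand-ins: a lock plays the role of TensorFlow's resource handle.
class StandInWriter:
    def __init__(self):
        self._lock = threading.Lock()


class StandInCallback:
    def __init__(self):
        self.log_dir = "testlogs"
        self._writers = {"train": StandInWriter()}


print(find_unpicklable(StandInCallback(), "callback"))
# ["callback._writers['train']._lock"]
```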


adriangb commented 3 years ago

Thank you for the clean reproducible example.

Interestingly, SciKeras is actually the reason you are able to pickle the pure-Keras model: we monkey-patch tf.keras.Model to make it picklable (here). If you remove the SciKeras import, pickling that model will fail with an error along the lines of "can't pickle weakref object".
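
Conceptually, such a monkey patch replaces the class's `__reduce__`, so that `pickle` delegates to the framework's own save/load routines instead of walking the object's attributes. The stdlib-only sketch below illustrates the technique; `Model`, `_save_to_bytes` and `_load_from_bytes` are hypothetical stand-ins, not SciKeras's actual patch, which (as explained in the following paragraph) relies on TensorFlow's own serialization.

```python
import pickle
import threading


class Model:
    """Hypothetical stand-in for a framework class that pickle cannot handle."""

    def __init__(self, weights):
        self.weights = list(weights)
        self._lock = threading.Lock()  # default pickling fails on this


def _save_to_bytes(model):
    # Stand-in for the framework's own serializer: keep only the real state.
    return pickle.dumps({"weights": model.weights})


def _load_from_bytes(blob):
    # Rebuild a fresh object (with a fresh lock) from the serialized state.
    return Model(pickle.loads(blob)["weights"])


def _model_reduce(self):
    # pickle stores (callable, args) and calls _load_from_bytes(blob) on load.
    return _load_from_bytes, (_save_to_bytes(self),)


try:
    pickle.dumps(Model([0.5, -1.0]))
except TypeError as exc:
    print("before patch:", type(exc).__name__)  # before patch: TypeError

# The monkey patch: all Model instances now pickle through the "native" format.
Model.__reduce__ = _model_reduce

restored = pickle.loads(pickle.dumps(Model([0.5, -1.0])))
print("after patch:", restored.weights)  # after patch: [0.5, -1.0]
```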

The issue itself stems from pickling the callback. Unfortunately, many things in TensorFlow can't be pickled with Python's standard pickling facilities, so SciKeras uses TensorFlow's own serialization support to serialize models. But TensorFlow's serialization support is limited and rigid: it can serialize an entire Model, but not standalone Callback or Optimizer instances. Generally this is OK, because things like optimizer instances are stored as part of the Keras model itself, so TensorFlow knows how to serialize them. Models don't hold references to callbacks, however; callbacks are only passed in as fit/predict arguments. So SciKeras has to hold references to the callbacks itself in order to support stateful callbacks, which of course causes an issue when pickling (I had not thought of this when I implemented callback support, so thank you for bringing it up).
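
To make the failure mode concrete, here is a stdlib-only sketch of the situation described above. `StandInTensorBoard` and `Wrapper` are hypothetical stand-ins (a `threading.Lock` plays the role of TensorFlow's resource handle), not SciKeras or TensorFlow classes. The wrapper keeps its constructor `callbacks` as an attribute, and the stand-in callback only opens its writer once training starts, which mirrors the report that pickling breaks only after fitting.

```python
import pickle
import threading


class StandInTensorBoard:
    """Hypothetical callback: opens an unpicklable 'writer' when training starts."""

    def __init__(self, log_dir):
        self.log_dir = log_dir
        self._writers = {}  # empty until fit() runs

    def on_train_begin(self):
        # A lock stands in for TensorFlow's summary-writer resource handle.
        self._writers["train"] = threading.Lock()


class Wrapper:
    """Hypothetical sklearn-style wrapper that stores its constructor params."""

    def __init__(self, callbacks=None):
        self.callbacks = callbacks  # stored on the instance, so pickled with it

    def fit(self, X, y):
        for cb in self.callbacks or []:
            cb.on_train_begin()
        self.model_ = {"weights": [0.1, 0.2]}  # picklable stand-in for the model
        return self


def picklable(obj):
    try:
        pickle.dumps(obj)
        return True
    except TypeError:
        return False


clf = Wrapper(callbacks=[StandInTensorBoard("testlogs")])
print(picklable(clf))                        # True: no writer opened yet
print(picklable(clf.fit(None, None)))        # False: writer reachable via clf.callbacks
print(picklable(Wrapper().fit(None, None)))  # True: no callbacks stored
```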

I'll have to think a bit on what can be done about this, but I hope that at least sheds some light on the issue for now.

adriangb commented 3 years ago

For what it's worth, the actual unpicklable object is TensorBoard._writers. Deleting it makes the callback picklable.
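
One way to act on this would be to drop that attribute from the pickled state and recreate it empty on load. Below is a sketch of that technique using `__getstate__`/`__setstate__` on a hypothetical stand-in class; applying the same idea to the real `TensorBoard` (for example via a subclass) is an untested assumption, since `_writers` is a private attribute that may differ between TensorFlow versions.

```python
import pickle
import threading


class StandInTensorBoard:
    """Hypothetical callback that opens an unpicklable writer when training starts."""

    def __init__(self, log_dir):
        self.log_dir = log_dir
        self._writers = {}

    def on_train_begin(self):
        self._writers["train"] = threading.Lock()  # stand-in for a summary writer


class PicklableTensorBoard(StandInTensorBoard):
    def __getstate__(self):
        # Copy the instance state but leave out the open writers.
        state = self.__dict__.copy()
        state.pop("_writers", None)
        return state

    def __setstate__(self, state):
        # Restore everything else; writers would be reopened on the next fit.
        self.__dict__.update(state)
        self._writers = {}


cb = PicklableTensorBoard("testlogs")
cb.on_train_begin()
restored = pickle.loads(pickle.dumps(cb))
print(restored.log_dir, restored._writers)  # testlogs {}
```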

adriangb commented 3 years ago

@joooeey The quickest solution to your problem is going to be to pass the callback to SciKeras as a fit kwarg:

from pickle import dumps
from numpy import random

from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
from scikeras.wrappers import KerasClassifier

# %% shared data

X = random.random((10, 6))
y = random.randint(2, size=10)

def build_fn():
    """Build sequential neural network."""
    model = Sequential()
    model.add(Dense(30, activation="relu", input_shape=(6, )))
    model.add(Dense(20, activation="relu"))
    model.add(Dense(1, activation="sigmoid"))
    return model

clf = KerasClassifier(
    model=build_fn,
    loss="binary_crossentropy",
)
# Pass the callback to fit instead of the constructor, so SciKeras does not store it.
clf =, y, callbacks=[TensorBoard("testlogs")])

dumps(clf)  # pickles without error


Note however that:

  1. The callback won't be serialized / saved if you pickle/unpickle the model.
  2. You can't hyperparameter-tune it (not that it would make sense in this case), but you can still pass it through GridSearchCV and other hyperparameter tuning tools if they support **fit_args.
  3. You can't use partial_fit (and hence Dask) since our partial_fit doesn't support passing arbitrary arguments.
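
Note 1 follows from where the callback lives: a fit keyword argument is only a local variable during training, so nothing on the fitted estimator refers to it afterwards. A minimal sketch with hypothetical stand-in classes (not SciKeras's implementation):

```python
import pickle
import threading


class StandInTensorBoard:
    """Hypothetical callback that becomes unpicklable once training starts."""

    def __init__(self, log_dir):
        self.log_dir = log_dir
        self._writers = {}

    def on_train_begin(self):
        self._writers["train"] = threading.Lock()  # stand-in for a summary writer


class Wrapper:
    """Hypothetical wrapper: callbacks given to fit() are used, never stored."""

    def __init__(self, epochs=1):
        self.epochs = epochs  # constructor params are stored (and tunable)

    def fit(self, X, y, callbacks=None):
        for cb in callbacks or []:  # local variable only, not an attribute
            cb.on_train_begin()
        self.model_ = {"weights": [0.1, 0.2]}  # picklable stand-in for the model
        return self


tb = StandInTensorBoard("testlogs")
clf = Wrapper().fit(None, None, callbacks=[tb])
restored = pickle.loads(pickle.dumps(clf))  # works: clf holds no callback
print(restored.model_, hasattr(restored, "callbacks"))  # {'weights': [0.1, 0.2]} False
```

The flip side is exactly note 1: after unpickling, the callback is gone and has to be passed to fit again.
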
adriangb commented 3 years ago

@stsievert I think this use case is the nail in the coffin for having to keep around **kwargs

stsievert commented 3 years ago

the nail in the coffin for having to keep around **kwargs

Do you have a reference issue/PR? I presume you mean passing parameters/keyword arguments through fit instead of always specifying those parameters at initialization. Is that correct?

adriangb commented 3 years ago

Do you have a reference issue/PR?

We discussed the topic several times previously, e.g. #198 and #138.

you mean passing parameters/keyword arguments through fit instead of always specifying those parameters at initialization

Yes, exactly.

What I mean is that we'll have to keep both constructor parameters (sklearn style) and fit/predict **kwargs (Keras style) around, and fully support both. In practice, I think this just means removing any remaining wording/warnings about kwargs deprecation, and adding kwargs support to partial_fit.

The sklearn constructor parameters are necessary for the sklearn ecosystem to work (e.g. dask-ml hyperparameter tuning), and wherever possible we should encourage them, but certain Keras/TF use cases (in particular the one presented in this issue) simply aren't compatible with sklearn-style constructor parameters and require **kwargs.

stsievert commented 3 years ago

Yeah, this is good motivation to keep that behavior (passing keyword arguments through to fit).

I think that behavior should be strongly discouraged.

joooeey commented 3 years ago

@joooeey The quickest solution to your problem is going to be to pass the callback to SciKeras as a fit kwarg: [...]

Unfortunately, in my real code the classifier is fitted inside a pipeline, so now I'm passing the callback through the pipeline's fit: `fit(X, y, kerasclassifier__callbacks=[TensorBoard()])`.

This raises

UserWarning: Passing estimator parameters as keyword arguments (aka as `**kwargs`) to `fit` is not supported by the Scikit-Learn API, and will be removed in a future version of SciKeras.

To resolve this issue, either set these parameters in the constructor (e.g., `est = BaseWrapper(..., foo=bar)`) or via `set_params` (e.g., `est.set_params(foo=bar)`). The following parameters were passed to `fit`:

`callbacks=[<tensorflow.python.keras.callbacks.TensorBoard object at 0x7f539c414460>]`

More detail is available at

By the way, the link in the warning doesn't work.
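
The `kerasclassifier__callbacks` key in the call above relies on scikit-learn's `Pipeline` convention of addressing fit parameters as `<step>__<param>`; `make_pipeline` names each step after its lowercased class name, which is presumably where `kerasclassifier` comes from. The routing idea can be sketched in plain Python; `route_fit_params` and the `standardscaler` step are hypothetical, and this is not scikit-learn's own code.

```python
def route_fit_params(step_names, **fit_params):
    """Split 'step__param' keys into per-step dicts, as a Pipeline-style router would.

    Hypothetical helper illustrating the naming convention only.
    """
    routed = {name: {} for name in step_names}
    for key, value in fit_params.items():
        # Split on the first "__" only, so nested names stay with the step.
        step, sep, param = key.partition("__")
        if not sep or step not in routed:
            raise ValueError(f"cannot route fit parameter {key!r}")
        routed[step][param] = value
    return routed


# make_pipeline-style step names: lowercased class names.
steps = ["standardscaler", "kerasclassifier"]
print(route_fit_params(steps, kerasclassifier__callbacks=["<TensorBoard>"]))
# {'standardscaler': {}, 'kerasclassifier': {'callbacks': ['<TensorBoard>']}}
```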

adriangb commented 3 years ago

You can safely ignore that warning; as discussed above, we'll probably remove it in the future. Sorry for any confusion this may cause. Other than the warning, does that work for you?
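
If the warning is noisy in the meantime, it can be silenced selectively with Python's standard `warnings` module, matching on the start of the message quoted above. In this sketch `fit_quietly` and `fake_fit` are hypothetical helpers; `fake_fit` only imitates the warning SciKeras emitted.

```python
import warnings

# `message` is a regex matched against the start of the warning text.
KWARGS_WARNING = r"Passing estimator parameters as keyword arguments"


def fit_quietly(fit, *args, **kwargs):
    """Call `fit` with only that one UserWarning suppressed; others still show."""
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=KWARGS_WARNING, category=UserWarning)
        return fit(*args, **kwargs)


def fake_fit(X, y, **kwargs):
    # Hypothetical stand-in that emits the same kind of warning as SciKeras did.
    if kwargs:
        warnings.warn(
            "Passing estimator parameters as keyword arguments (aka as `**kwargs`) "
            "to `fit` is not supported by the Scikit-Learn API",
            UserWarning,
        )
    return "fitted"


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record every warning that is not filtered
    result = fit_quietly(fake_fit, None, None, callbacks=["<TensorBoard>"])
print(result, len(caught))  # fitted 0
```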

By the way, the link in the warning doesn't work.

You're right, thank you for catching that. This is the correct link, and I'll update the warning (or remove it):

joooeey commented 3 years ago

Yeah, this works for me now.

adriangb commented 3 years ago

Awesome, I'm glad we found you a solution, even if it's not ideal.

Like I said above, we will probably disable those warnings so that this API will be more straightforward to use going forward.

Your feedback has been very valuable, so thank you for the issue and bearing with me during troubleshooting.