ageron / handson-ml3

A series of Jupyter notebooks that walk you through the fundamentals of Machine Learning and Deep Learning in Python using Scikit-Learn, Keras and TensorFlow 2.
Apache License 2.0
7.84k stars 3.14k forks source link

[QUESTION] CHAPTER 2: Custom Transformer Issue with Joblib #134

Open AdrianVazquezTorres opened 5 months ago

AdrianVazquezTorres commented 5 months ago

Hi there, for the last few days I've been dealing with a problem with the custom transformer ClusterSimilarity (notebook: 02_end_to_end_machine_learning_project | cell: 96). The custom transformer computes its results correctly, but when I try to save it using joblib, I get the following error:

'NoneType' object is not iterable

I get the same error whether I save it inside a pipeline or on its own. I found that the problem lies in the transform method: if I remove it, I can save the custom transformer without problems.

IMPORTANT: This problem occurs with every custom transformer I try to save, not only this one.

My code is exactly the same as in the book:

class ClusterSimilarity(BaseEstimator, TransformerMixin):
    """Map each sample to its RBF similarity with k-means cluster centers.

    ``fit`` clusters the input with ``KMeans``; ``transform`` then returns,
    for every sample, one column per cluster center holding
    ``rbf_kernel(sample, center, gamma=self.gamma)``.

    Parameters
    ----------
    n_clusters : int, default 10
        Number of clusters requested from ``KMeans``; also the number of
        output feature names produced by ``get_feature_names_out``.
    gamma : float, default 1.0
        Bandwidth forwarded to ``rbf_kernel``.
    random_state : int or None, default None
        Seed forwarded to ``KMeans``.

    NOTE(review): the joblib failure reported alongside this class goes
    through dill's ``save_type`` / ``save_function`` handlers, which means
    dill's pickling hooks were active inside joblib's pickler. That points to
    the environment (installed dill version) rather than this class; confirm
    before changing the code.
    """

    def __init__(self, n_clusters=10, gamma=1.0, random_state=None):
        # scikit-learn convention: store constructor arguments unchanged.
        self.n_clusters, self.gamma, self.random_state = (
            n_clusters,
            gamma,
            random_state,
        )

    def fit(self, X, y=None, sample_weight=None):
        """Fit ``KMeans`` on ``X`` (``y`` is ignored) and return ``self``."""
        self.kmeans_ = KMeans(
            n_clusters=self.n_clusters,
            n_init=10,
            random_state=self.random_state,
        )
        self.kmeans_.fit(X, sample_weight=sample_weight)
        return self

    def transform(self, X):
        """Return the RBF similarity of each row of ``X`` to each fitted center."""
        centers = self.kmeans_.cluster_centers_
        return rbf_kernel(X, centers, gamma=self.gamma)

    def get_feature_names_out(self, names=None):
        """Return one output column name per cluster; ``names`` is unused."""
        return [f"Cluster {idx} similarity" for idx in range(self.n_clusters)]

cluster_simil = ClusterSimilarity(
    n_clusters=10,
    gamma=1,
    random_state=42,
)

Here is the complete traceback:

TypeError                                 Traceback (most recent call last)
Cell In[43], line 1
----> 1 joblib.dump(cluster_simil, "test.pkl")

File ~\anaconda3\envs\data_science_py310\lib\site-packages\joblib\numpy_pickle.py:553, in dump(value, filename, compress, protocol, cache_size)
    551 elif is_filename:
    552     with open(filename, 'wb') as f:
--> 553         NumpyPickler(f, protocol=protocol).dump(value)
    554 else:
    555     NumpyPickler(filename, protocol=protocol).dump(value)

File ~\anaconda3\envs\data_science_py310\lib\pickle.py:487, in _Pickler.dump(self, obj)
    485 if self.proto >= 4:
    486     self.framer.start_framing()
--> 487 self.save(obj)
    488 self.write(STOP)
    489 self.framer.end_framing()

File ~\anaconda3\envs\data_science_py310\lib\site-packages\joblib\numpy_pickle.py:355, in NumpyPickler.save(self, obj)
    352     wrapper.write_array(obj, self)
    353     return
--> 355 return Pickler.save(self, obj)

File ~\anaconda3\envs\data_science_py310\lib\pickle.py:603, in _Pickler.save(self, obj, save_persistent_id)
    599     raise PicklingError("Tuple returned by %s must have "
    600                         "two to six elements" % reduce)
    602 # Save the reduce() output and finally memoize the object
--> 603 self.save_reduce(obj=obj, *rv)

File ~\anaconda3\envs\data_science_py310\lib\pickle.py:687, in _Pickler.save_reduce(self, func, args, state, listitems, dictitems, state_setter, obj)
    684     raise PicklingError(
    685         "args[0] from __newobj__ args has the wrong class")
    686 args = args[1:]
--> 687 save(cls)
    688 save(args)
    689 write(NEWOBJ)

File ~\anaconda3\envs\data_science_py310\lib\site-packages\joblib\numpy_pickle.py:355, in NumpyPickler.save(self, obj)
    352     wrapper.write_array(obj, self)
    353     return
--> 355 return Pickler.save(self, obj)

File ~\anaconda3\envs\data_science_py310\lib\pickle.py:560, in _Pickler.save(self, obj, save_persistent_id)
    558 f = self.dispatch.get(t)
    559 if f is not None:
--> 560     f(self, obj)  # Call unbound method with explicit self
    561     return
    563 # Check private dispatch table if any, or else
    564 # copyreg.dispatch_table

File ~\anaconda3\envs\data_science_py310\lib\site-packages\dill\_dill.py:1832, in save_type(pickler, obj, postproc_list)
   1829     postproc_list.append((setattr, (obj, '__qualname__', qualname)))
   1831 if not hasattr(obj, '__orig_bases__'):
-> 1832     _save_with_postproc(pickler, (_create_type, (
   1833         type(obj), obj.__name__, obj.__bases__, _dict
   1834     )), obj=obj, postproc_list=postproc_list)
   1835 else:
   1836     # This case will always work, but might be overkill.
   1837     _metadict = {
   1838         'metaclass': type(obj)
   1839     }

File ~\anaconda3\envs\data_science_py310\lib\site-packages\dill\_dill.py:1098, in _save_with_postproc(pickler, reduction, is_pickler_dill, obj, postproc_list)
   1095     pickler._postproc[id(obj)] = postproc_list
   1097 # TODO: Use state_setter in Python 3.8 to allow for faster cPickle implementations
-> 1098 pickler.save_reduce(*reduction, obj=obj)
   1100 if is_pickler_dill:
   1101     # pickler.x -= 1
   1102     # print(pickler.x*' ', 'pop', obj, id(obj))
   1103     postproc = pickler._postproc.pop(id(obj))

File ~\anaconda3\envs\data_science_py310\lib\pickle.py:692, in _Pickler.save_reduce(self, func, args, state, listitems, dictitems, state_setter, obj)
    690 else:
    691     save(func)
--> 692     save(args)
    693     write(REDUCE)
    695 if obj is not None:
    696     # If the object is already in the memo, this means it is
    697     # recursive. In this case, throw away everything we put on the
    698     # stack, and fetch the object back from the memo.

File ~\anaconda3\envs\data_science_py310\lib\site-packages\joblib\numpy_pickle.py:355, in NumpyPickler.save(self, obj)
    352     wrapper.write_array(obj, self)
    353     return
--> 355 return Pickler.save(self, obj)

File ~\anaconda3\envs\data_science_py310\lib\pickle.py:560, in _Pickler.save(self, obj, save_persistent_id)
    558 f = self.dispatch.get(t)
    559 if f is not None:
--> 560     f(self, obj)  # Call unbound method with explicit self
    561     return
    563 # Check private dispatch table if any, or else
    564 # copyreg.dispatch_table

File ~\anaconda3\envs\data_science_py310\lib\pickle.py:902, in _Pickler.save_tuple(self, obj)
    900 write(MARK)
    901 for element in obj:
--> 902     save(element)
    904 if id(obj) in memo:
    905     # Subtle.  d was not in memo when we entered save_tuple(), so
    906     # the process of saving the tuple's elements must have saved
   (...)
    910     # could have been done in the "for element" loop instead, but
    911     # recursive tuples are a rare thing.
    912     get = self.get(memo[id(obj)][0])

File ~\anaconda3\envs\data_science_py310\lib\site-packages\joblib\numpy_pickle.py:355, in NumpyPickler.save(self, obj)
    352     wrapper.write_array(obj, self)
    353     return
--> 355 return Pickler.save(self, obj)

File ~\anaconda3\envs\data_science_py310\lib\pickle.py:560, in _Pickler.save(self, obj, save_persistent_id)
    558 f = self.dispatch.get(t)
    559 if f is not None:
--> 560     f(self, obj)  # Call unbound method with explicit self
    561     return
    563 # Check private dispatch table if any, or else
    564 # copyreg.dispatch_table

File ~\anaconda3\envs\data_science_py310\lib\site-packages\dill\_dill.py:1217, in save_module_dict(pickler, obj)
   1214     if is_dill(pickler, child=False) and pickler._session:
   1215         # we only care about session the first pass thru
   1216         pickler._first_pass = False
-> 1217     StockPickler.save_dict(pickler, obj)
   1218     logger.trace(pickler, "# D2")
   1219 return

File ~\anaconda3\envs\data_science_py310\lib\pickle.py:972, in _Pickler.save_dict(self, obj)
    969     self.write(MARK + DICT)
    971 self.memoize(obj)
--> 972 self._batch_setitems(obj.items())

File ~\anaconda3\envs\data_science_py310\lib\pickle.py:998, in _Pickler._batch_setitems(self, items)
    996     for k, v in tmp:
    997         save(k)
--> 998         save(v)
    999     write(SETITEMS)
   1000 elif n:

File ~\anaconda3\envs\data_science_py310\lib\site-packages\joblib\numpy_pickle.py:355, in NumpyPickler.save(self, obj)
    352     wrapper.write_array(obj, self)
    353     return
--> 355 return Pickler.save(self, obj)

File ~\anaconda3\envs\data_science_py310\lib\pickle.py:560, in _Pickler.save(self, obj, save_persistent_id)
    558 f = self.dispatch.get(t)
    559 if f is not None:
--> 560     f(self, obj)  # Call unbound method with explicit self
    561     return
    563 # Check private dispatch table if any, or else
    564 # copyreg.dispatch_table

File ~\anaconda3\envs\data_science_py310\lib\site-packages\dill\_dill.py:1960, in save_function(pickler, obj)
   1955 if globs_copy is not None and globs is not globs_copy:
   1956     # In the case that the globals are copied, we need to ensure that
   1957     # the globals dictionary is updated when all objects in the
   1958     # dictionary are already created.
   1959     glob_ids = {id(g) for g in globs_copy.values()}
-> 1960     for stack_element in _postproc:
   1961         if stack_element in glob_ids:
   1962             _postproc[stack_element].append((_setitems, (globs, globs_copy)))

TypeError: 'NoneType' object is not iterable

Versions:

Thank you for reading.