Closed: ragrawal closed this issue 4 years ago.
Thank you for opening this issue.
It seems that CatBoost does not play well with subclasses in general. Even if you don't use baikal at all, pickling and unpickling an instance of a subclass fails to restore the subclass's own attributes.
```python
import dill
from catboost import CatBoostClassifier


class CatBoostClassifierSub(CatBoostClassifier):
    def __init__(self):
        super().__init__()
        self._nodes = []


a = CatBoostClassifierSub()
b = dill.loads(dill.dumps(a))
assert hasattr(a, "_nodes")  # passes
assert hasattr(b, "_nodes")  # fails
```
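The same effect can be reproduced without CatBoost at all. Below is a minimal standard-library sketch, assuming only that the parent class defines a `__getstate__` that returns a fixed set of attributes. `Base`, `Sub`, and `_params` are hypothetical names chosen for illustration, not CatBoost internals:

```python
import pickle


class Base:
    """Hypothetical stand-in for CatBoostClassifier: it pickles only the
    attributes it knows about and ignores anything a subclass adds."""

    def __init__(self):
        self._params = {"depth": 6}

    def __getstate__(self):
        # Persist only the base-class attributes.
        return {"_params": self._params}

    def __setstate__(self, state):
        self._params = state["_params"]


class Sub(Base):
    def __init__(self):
        super().__init__()
        self._nodes = []


a = Sub()
b = pickle.loads(pickle.dumps(a))
print(hasattr(a, "_nodes"))  # True
print(hasattr(b, "_nodes"))  # False: _nodes was never written to the pickle
```

Unpickling does not call `__init__`, so anything not returned by `__getstate__` is simply gone.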
I think this is because `CatBoostClassifier` implements its own `__getstate__` method (here) that persists only specific attributes. So it seems that when you subclass it and want the subclass to be pickle-able, you have to override the `__getstate__` and `__setstate__` methods so that they call the parent class's method and add the attributes of the child class. Something along these lines:
```python
from baikal import Step
from catboost import CatBoostClassifier


class CatBoostClassifierStep(Step, CatBoostClassifier):
    def __init__(self, *args, name=None, **kwargs):
        super().__init__(*args, name=name, **kwargs)

    def __getstate__(self):
        state = super().__getstate__()
        state["_name"] = self._name
        state["_nodes"] = self._nodes
        state["_n_outputs"] = self._n_outputs
        # NOTE: as pointed out later in this thread, this must end with `return state`

    def __setstate__(self, state):
        super().__setstate__(state)
        self._name = state["_name"]
        self._nodes = state["_nodes"]
        self._n_outputs = state["_n_outputs"]
```
Can you try the above and see if it works (it might need some adjustments; I haven't tried this myself)? I don't think you need to implement `__getstate__` and `__setstate__` for any other classes.
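The override pattern suggested above can be checked in isolation. This is a standard-library sketch, again with a hypothetical `Base` standing in for `CatBoostClassifier`, and assuming (as the working version later in this thread suggests) that the parent's `__setstate__` ignores keys it does not recognize:

```python
import pickle


class Base:
    """Hypothetical parent class with its own __getstate__/__setstate__."""

    def __init__(self):
        self._params = {"depth": 6}

    def __getstate__(self):
        return {"_params": self._params}

    def __setstate__(self, state):
        self._params = state["_params"]


class Sub(Base):
    def __init__(self, name=None):
        super().__init__()
        self._name = name
        self._nodes = []

    def __getstate__(self):
        # Start from whatever the parent persists, then add our own attributes.
        state = super().__getstate__()
        state["_name"] = self._name
        state["_nodes"] = self._nodes
        return state

    def __setstate__(self, state):
        # Let the parent restore its attributes, then restore ours.
        super().__setstate__(state)
        self._name = state["_name"]
        self._nodes = state["_nodes"]


b = pickle.loads(pickle.dumps(Sub(name="cat")))
print(b._name, b._nodes, b._params)  # cat [] {'depth': 6}
```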
Answering your question from the PR thread:

> do you have an example of serialization?

There is an example in the tests here.
For anyone else: the solution suggested by alegonz works. There was just one problem: `__getstate__` didn't include a `return` statement. For reference, this is how to create a `CatBoostClassifier` step:
```python
from baikal import Step
from catboost import CatBoostClassifier


class CatBoostClassifierStep(Step, CatBoostClassifier):
    def __init__(self, *args, name=None, n_outputs=1, **kwargs):
        super().__init__(*args, name=name, n_outputs=n_outputs, **kwargs)
        self._nodes = []

    def __getstate__(self):
        state = super().__getstate__()
        state["_name"] = self._name
        state["_nodes"] = self._nodes
        state["_n_outputs"] = self._n_outputs
        return state  # make sure to return the state

    def __setstate__(self, state):
        self._name = state["_name"]
        self._nodes = state["_nodes"]
        self._n_outputs = state["_n_outputs"]
        super().__setstate__(state)

    def __hash__(self):  # include this, otherwise you will get a hashing error
        return hash(super().name)
```
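Why the missing `return` matters: without it, `__getstate__` returns `None`, and per the `pickle` documentation, `__setstate__` is not called on unpickling when `__getstate__` returns a false value. None of the attributes come back, so the first access to `_name` raises the same kind of `object has no attribute '_name'` error as in the report. A standard-library sketch with hypothetical class names:

```python
import pickle


class Broken:
    def __init__(self):
        self._name = "step"

    def __getstate__(self):
        state = {"_name": self._name}
        # Missing `return state`: this method implicitly returns None.

    def __setstate__(self, state):
        self._name = state["_name"]


class Fixed(Broken):
    def __getstate__(self):
        return {"_name": self._name}


b = pickle.loads(pickle.dumps(Broken()))
print(hasattr(b, "_name"))  # False: __setstate__ was never called

f = pickle.loads(pickle.dumps(Fixed()))
print(f._name)  # step
```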
@ragrawal Thank you for confirming! I'll add an entry to the documentation about this special case of CatBoost estimators.
I'm planning to patch baikal to handle the case of unhashable steps reported in issue #37, so that the `__hash__` workaround is no longer needed.
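One likely source of the hashing error, stated here as an assumption about CatBoost's internals rather than something confirmed in this thread: in Python, a class that defines `__eq__` without also defining `__hash__` gets `__hash__ = None`, so its instances cannot be used as set members or dict keys. A sketch with a hypothetical `Model` class, plus a `__hash__` override in the spirit of the workaround above:

```python
class Model:
    """Hypothetical estimator that defines __eq__ but not __hash__."""

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        return isinstance(other, Model) and self.name == other.name


print(Model.__hash__)  # None: defining __eq__ alone disables hashing

try:
    {Model("a")}
    hashable = True
except TypeError:
    hashable = False
print(hashable)  # False


class HashableModel(Model):
    # Workaround: hash on the name, consistent with __eq__ comparing names.
    def __hash__(self):
        return hash(self.name)


count = len({HashableModel("a"), HashableModel("b"), HashableModel("a")})
print(count)  # 2
```

Hashing on the name keeps `__hash__` consistent with `__eq__`: objects that compare equal also hash equal.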
What is the bug? When serializing/deserializing a CatBoost model using joblib, deserialization fails with the following error:
`object has no attribute '_name'`
How to reproduce it?
What versions are you using? baikal==0.4.1 scikit-learn==0.23.1 python==3.6
Any additional information?