alegonz / baikal

A graph-based functional API for building complex scikit-learn pipelines.
https://baikal.readthedocs.io
BSD 3-Clause "New" or "Revised" License

Deserialization fails when directly inheriting the Step class #40

Closed · ragrawal closed this issue 4 years ago

ragrawal commented 4 years ago

What is the bug? When serializing/deserializing a CatBoost model using joblib, deserialization fails with the following error: object has no attribute '_name'

How to reproduce it?

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from baikal import Input, Model, make_step, Step
from baikal.steps import Stack
from catboost import CatBoostClassifier
from sklearn.linear_model import LinearRegression

# load data
df = pd.read_csv('prime.csv',
#     'https://raw.githubusercontent.com/jbrownlee/Datasets/master/pima-indians-diabetes.data.csv', 
    header=None)
dataset = df.values

class CatBoostClassifierStep(Step, CatBoostClassifier):
    def __init__(self, *args, name=None, n_outputs=1, **kwargs):
        super().__init__(*args, name=name, n_outputs=n_outputs, **kwargs)

    def __hash__(self):
        return hash(super().name)

x = Input()
y = Input()

catboostStep = CatBoostClassifierStep()(x, y)
model = Model(x, catboostStep, y)
model.fit(dataset[:,0:8], dataset[:,8])

import joblib
s = joblib.dump(model, 't.pkl')

model2 = joblib.load('t.pkl')
model2.predict(dataset[:,0:8])

What versions are you using? baikal==0.4.1 scikit-learn==0.23.1 python==3.6

Any additional information?

AttributeError          Traceback (most recent call last)
<ipython-input-5-139f76755734> in <module>
     31 
     32 model2 = joblib.load('t.pkl')
---> 33 model2.predict(dataset[:,0:8])

/opt/anaconda3/envs/plaid/lib/python3.7/site-packages/baikal/_core/model.py in predict(self, X, output_names)
    468         # without having to change the inputs accordingly.
    469         nodes = self._get_required_nodes(
--> 470             X_norm, [], outputs, allow_unused_inputs=True, follow_targets=False
    471         )
    472 

/opt/anaconda3/envs/plaid/lib/python3.7/site-packages/baikal/_core/model.py in _get_required_nodes(self, given_inputs, given_targets, desired_outputs, allow_unused_inputs, allow_unused_targets, follow_targets, ignore_trainable_false)
    189 
    190         for output in desired_outputs:
--> 191             required_nodes |= backtrack(output)
    192 
    193         # Check for missing inputs/targets

/opt/anaconda3/envs/plaid/lib/python3.7/site-packages/baikal/_core/model.py in backtrack(output)
    174                 return nodes_required_by_output
    175 
--> 176             parent_node = output.node
    177             if parent_node in required_nodes:
    178                 return nodes_required_by_output

/opt/anaconda3/envs/plaid/lib/python3.7/site-packages/baikal/_core/data_placeholder.py in node(self)
     42     @property
     43     def node(self):
---> 44         return self.step._nodes[self.port]
     45 
     46     def __repr__(self):

AttributeError: 'CatBoostClassifierStep' object has no attribute '_nodes'
alegonz commented 4 years ago

Thank you for opening this issue.

It seems that catboost does not play well with subclasses in general. Even if you don't use baikal at all, pickling an instance of a subclass will fail to recover the attributes of the subclass.

import dill
from catboost import CatBoostClassifier

class CatBoostClassifierSub(CatBoostClassifier):
    def __init__(self):
        super().__init__()
        self._nodes = []

a = CatBoostClassifierSub()
b = dill.loads(dill.dumps(a))

assert hasattr(a, "_nodes")  # passes
assert hasattr(b, "_nodes")  # fails

I think this is because CatBoostClassifier implements its own __getstate__ method (here) that persists only specific attributes. So it seems that when you subclass from it and want the subclass to be pickle-able, you have to override the __getstate__ and __setstate__ methods so that they call the parent class's methods and also handle the attributes of the child class. Something along these lines:

class CatBoostClassifierStep(Step, CatBoostClassifier):
    def __init__(self, *args, name=None, **kwargs):
        super().__init__(*args, name=name, **kwargs)

    def __getstate__(self):
        state = super().__getstate__()
        state["_name"] = self._name
        state["_nodes"] = self._nodes
        state["_n_outputs"] = self._n_outputs

    def __setstate__(self, state):
        super().__setstate__(state)
        self._name = state["_name"]
        self._nodes = state["_nodes"]
        self._n_outputs = state["_n_outputs"]

Can you try the above and see if it works (might need some adjustments, I haven't tried this myself)? (I don't think you need to implement __setstate__ and __getstate__ for any other classes)


Answering your question from the PR thread:

do you have an example of serialization ?

There is an example in the tests here.
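For reference without following the link, a minimal serialization round trip of a baikal Model might look like the sketch below (illustrative only, not the actual test code; the LogisticRegression step, toy data, and file name are placeholders):

import joblib
import numpy as np
from sklearn.linear_model import LogisticRegression
from baikal import Input, Model, make_step

# wrap a scikit-learn estimator as a baikal step
LogisticRegressionStep = make_step(LogisticRegression)

x = Input()
y_t = Input()
y_pred = LogisticRegressionStep()(x, y_t)
model = Model(x, y_pred, y_t)

# fit on toy data, then dump and reload with joblib
X = np.random.rand(20, 4)
y = np.random.randint(0, 2, size=20)
model.fit(X, y)

joblib.dump(model, "model.pkl")
model2 = joblib.load("model.pkl")
model2.predict(X)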

ragrawal commented 4 years ago

For anyone else: the solution suggested by @alegonz actually works. There was just one problem: __getstate__ didn't include a return statement. For reference, this is how to create a CatBoostClassifier step:

class CatBoostClassifierStep(Step, CatBoostClassifier):
    def __init__(self, *args, name=None, n_outputs=1, **kwargs):
        super().__init__(*args, name=name, n_outputs=n_outputs, **kwargs)
        self._nodes = []

    def __getstate__(self):
        state = super().__getstate__()
        state["_name"] = self._name
        state["_nodes"] = self._nodes
        state["_n_outputs"] = self._n_outputs
        return state  # make sure to return the state

    def __setstate__(self, state):
        self._name = state["_name"]
        self._nodes = state["_nodes"]
        self._n_outputs = state["_n_outputs"]
        super().__setstate__(state)

    def __hash__(self):  # include this, otherwise you will get a hashing error
        return hash(super().name)
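As a quick sanity check (a sketch along the lines of the earlier dill snippet, not something from the original thread), the attribute loss demonstrated above no longer occurs with this class:

import dill

a = CatBoostClassifierStep(name="catboost")
b = dill.loads(dill.dumps(a))

assert hasattr(b, "_nodes")     # now passes
assert b._name == a._name      # custom attributes survive the round trip
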
alegonz commented 4 years ago

@ragrawal Thank you for confirming! I'll add an entry in the documentation about this special case of catboost estimators.

I'm planning to patch baikal to address the case of un-hashable steps reported in Issue #37 and avoid the __hash__ workaround.
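
For readers wondering why the __hash__ override is needed at all: in Python, a class that defines __eq__ without also defining __hash__ gets __hash__ set to None, and subclasses inherit that, so their instances become unhashable. A generic illustration of that behavior (unrelated to catboost internals):

class Base:
    def __eq__(self, other):
        return isinstance(other, type(self))
    # no __hash__ defined here, so Python sets __hash__ = None

class Sub(Base):
    pass

hash(Sub())  # raises TypeError: unhashable type: 'Sub'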