david-cortes / contextualbandits

Python implementations of contextual bandits algorithms
http://contextual-bandits.readthedocs.io
BSD 2-Clause "Simplified" License

XGBClassifier becomes un-serializable after being used as a base_model #54

Closed: tinyrickguy closed this issue 2 years ago

tinyrickguy commented 2 years ago

When XGBoost models are used as base models for online policies such as EpsilonGreedy, attributes are added to the XGBoost model instances. These attributes don't exist in the original model class, and they prevent the fitted models from being deserialized with either pickle or dill.

usage:

import dill
import pandas as pd
import contextualbandits as cb
from xgboost import XGBClassifier

arms = ["a", "b", "c"]
base_model = XGBClassifier(n_estimators=20)
cb_model = cb.online.EpsilonGreedy(base_model, nchoices=arms)

# Minimal training data: one context row, one chosen arm, one reward
X = pd.DataFrame([0])
a = pd.Series(["a"])
r = pd.Series([1])
cb_model.fit(X, a, r)

dill.loads(dill.dumps(cb_model))  # dill.loads() fails with the error below:
AttributeError                            Traceback (most recent call last)
/var/folders/g5/lpnvjwrd2h95lf50zlb22dt00000gn/T/ipykernel_70859/3296061279.py in <module>
      9 cb_model.fit(X, a, r)
     10 
---> 11 dill.loads(dill.dumps(cb_model))

~/miniconda3/envs/env1/lib/python3.9/site-packages/dill/_dill.py in loads(str, ignore, **kwds)
    385     """
    386     file = StringIO(str)
--> 387     return load(file, ignore, **kwds)
    388 
    389 # def dumpzs(obj, protocol=None):

~/miniconda3/envs/env1/lib/python3.9/site-packages/dill/_dill.py in load(file, ignore, **kwds)
    371     See :func:`loads` for keyword arguments.
    372     """
--> 373     return Unpickler(file, ignore=ignore, **kwds).load()
    374 
    375 def loads(str, ignore=None, **kwds):

~/miniconda3/envs/env1/lib/python3.9/site-packages/dill/_dill.py in load(self)
    644 
    645     def load(self): #NOTE: if settings change, need to update attributes
--> 646         obj = StockUnpickler.load(self)
    647         if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'):
    648             if not self._ignore:

AttributeError: 'XGBClassifier' object has no attribute '_decision_function_w_sigmoid_from_predict'
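
The error is consistent with a helper function being attached to the model instance as a bound method under a different attribute name. That is an assumption about the cause, not something confirmed in this thread. A minimal sketch of that mechanism using only the standard library follows; `TinyModel`, `_decision_from_predict` and `decision_function` are hypothetical names, not the library's code:

```python
import pickle
import types


class TinyModel:
    """Hypothetical stand-in for an estimator such as XGBClassifier."""

    def predict(self, x):
        return x * 2


def _decision_from_predict(self, x):
    # Hypothetical helper; its __name__ differs from the attribute it is stored under.
    return self.predict(x)


model = TinyModel()
# Attach the helper to this one instance as a bound method, under a new name.
model.decision_function = types.MethodType(_decision_from_predict, model)

# Serializing succeeds: pickle records the bound method as
# getattr(<instance>, "_decision_from_predict") without checking that it resolves.
payload = pickle.dumps(model)

try:
    pickle.loads(payload)
    error_message = None
except AttributeError as exc:
    # On load, getattr runs against the rebuilt instance, which has no such attribute.
    error_message = str(exc)

print(error_message)
```

As in the traceback above, the failure only shows up when loading, and the missing attribute is named after the helper function rather than after the attribute it was assigned to.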

david-cortes commented 2 years ago

Thanks for the bug report. I've updated the suggestions in the readme and docs to use cloudpickle instead.
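
For reference, a rough sketch of what the cloudpickle round trip looks like on a toy object. It assumes the `cloudpickle` package is installed and that the problem comes from a bound method attached to an instance; `TinyModel`, `_decision_from_predict` and `decision_function` are hypothetical stand-ins, not the library's code:

```python
import pickle
import types

import cloudpickle  # third-party package, assumed to be installed


class TinyModel:
    """Hypothetical stand-in for an estimator such as XGBClassifier."""

    def predict(self, x):
        return x * 2


def _decision_from_predict(self, x):
    # Hypothetical helper attached to the instance below.
    return self.predict(x)


model = TinyModel()
# Attach the helper to this one instance as a bound method, under a new name.
model.decision_function = types.MethodType(_decision_from_predict, model)

# cloudpickle serializes the bound method as its function plus its instance,
# so loading does not depend on looking the attribute up by name.
payload = cloudpickle.dumps(model)

# Payloads written by cloudpickle can be read back with the standard pickle module.
restored = pickle.loads(payload)

result = restored.decision_function(21)
print(result)
```

The same pattern should apply to a fitted policy object: `cloudpickle.dumps(cb_model)` to serialize and `pickle.loads(...)` to restore it.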

tinyrickguy commented 2 years ago

Thanks, working fine with cloudpickle.