david-cortes / contextualbandits

Python implementations of contextual bandits algorithms
http://contextual-bandits.readthedocs.io
BSD 2-Clause "Simplified" License

ParametricTS fails with: '_OneVsRest' object has no attribute 'beta_counters' #53

Closed by tinyrickguy 2 years ago

tinyrickguy commented 2 years ago

Trying to use ParametricTS raises the following error.

Usage:

import contextualbandits as cb
from xgboost import XGBRegressor

base_model = XGBRegressor(n_estimators=20)
cb_model = cb.online.ParametricTS(base_model, nchoices=actions)
...
cb_model.predict(df_context)

Environment:

contextualbandits 0.3.17.post3 (installed via pip)
Python 3.9.13
OS: macOS 12.5.1

Stack trace:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
~/miniconda3/envs/env1/lib/python3.9/site-packages/contextualbandits/online.py in predict(self, X, exploit, output_score)
    608             scores = self._exploit(X)
    609         else:
--> 610             scores = self.decision_function(X)
    611         pred = self._name_arms(np.argmax(scores, axis = 1))
    612 

~/miniconda3/envs/env1/lib/python3.9/site-packages/contextualbandits/online.py in decision_function(self, X)
    507             else:
    508                 return self._predict_from_beta_prior_and_smoothing(X.shape[0])
--> 509         return self._score_matrix(X)
    510 
    511     def _predict_from_beta_prior_and_smoothing(self, n):

~/miniconda3/envs/env1/lib/python3.9/site-packages/contextualbandits/online.py in _score_matrix(self, X)
   3341     def _score_matrix(self, X):
   3342         pred = self._oracles.decision_function(X)
-> 3343         counters = self._oracles.get_nobs_by_arm()
   3344         with_model = counters >= self.beta_prior[1]
   3345         counters = counters.reshape((1,-1))

~/miniconda3/envs/env1/lib/python3.9/site-packages/contextualbandits/utils.py in get_nobs_by_arm(self)
   1006 
   1007     def get_nobs_by_arm(self):
-> 1008         return self.beta_counters[1] + self.beta_counters[2]
   1009 
   1010     def exploit(self, X):

AttributeError: '_OneVsRest' object has no attribute 'beta_counters'

david-cortes commented 2 years ago

Thanks for the bug report. This is now fixed in the latest version:

pip install -U contextualbandits
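
After upgrading, it can help to confirm that the environment actually picked up a newer release than the one reported above. The following is a minimal sketch using only the standard library. The helpers `installed_version`, `version_key`, and `is_newer_than_affected` are hypothetical names introduced here for illustration and are not part of contextualbandits. The thread does not state which release contains the fix, so the check only tells you whether the installed version is newer than the affected 0.3.17.post3.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Version reported as failing in this issue.
AFFECTED = "0.3.17.post3"


def installed_version(package="contextualbandits"):
    """Return the installed version string, or None if the package is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def version_key(v):
    """Turn '0.3.17.post3' into ((0, 3, 17), 3) so versions can be compared.

    Simplified on purpose: handles only numeric releases with an optional
    '.postN' suffix, not the full PEP 440 grammar.
    """
    m = re.fullmatch(r"(\d+(?:\.\d+)*)(?:\.post(\d+))?", v)
    if m is None:
        raise ValueError(f"unsupported version string: {v!r}")
    release = tuple(int(part) for part in m.group(1).split("."))
    # A release without a post suffix sorts before its post-releases.
    post = int(m.group(2)) if m.group(2) else -1
    return release, post


def is_newer_than_affected(v):
    """True if version string v sorts strictly after the affected version."""
    return version_key(v) > version_key(AFFECTED)


if __name__ == "__main__":
    current = installed_version()
    if current is None:
        print("contextualbandits is not installed in this environment")
    elif is_newer_than_affected(current):
        print(f"{current} is newer than the affected {AFFECTED}")
    else:
        print(f"{current} is not newer than {AFFECTED}; the upgrade may not have applied")
```

If the script reports that the version is not newer, the upgrade most likely ran against a different environment than the one the notebook or script is using.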

tinyrickguy commented 2 years ago

Thank you for the quick fix!