david-cortes / contextualbandits

Python implementations of contextual bandits algorithms
http://contextual-bandits.readthedocs.io
BSD 2-Clause "Simplified" License
751 stars 148 forks source link

AssertionError: online_contextual_bandits.ipynb #62

Closed shyun46 closed 2 years ago

shyun46 commented 2 years ago

First of all, thank you for the code for using contextual bandits (CB) :>

When I run your example notebook (online_contextual_bandits.ipynb), I get an `AssertionError` in the "3.3 Streaming models" section. Could you give me a hint on how to fix this error?

```
AssertionError:

AssertionError                            Traceback (most recent call last)

in
     62     lst_actions[model],
     63     X_batch, y_batch,
---> 64     rnd_seed = batch_st)

in simulate_rounds_stoch(model, rewards, actions_hist, X_batch, y_batch, rnd_seed)
     31
     32     ## choosing actions for this batch
---> 33     actions_this_batch = model.predict(X_batch).astype('uint8')
     34
     35     # keeping track of the sum of rewards received

/databricks/python/lib/python3.7/site-packages/contextualbandits/online.py in predict(self, X, exploit)
   2003         if not self.is_fitted:
   2004             return self._predict_random_if_unfit(X, False)
-> 2005         return self._name_arms(self._predict(X, exploit, True))
   2006
   2007     def _predict(self, X, exploit = False, choose = True):

/databricks/python/lib/python3.7/site-packages/contextualbandits/online.py in _predict(self, X, exploit, choose)
   2029         # case 1: number of predictions to make would still fit within current window
   2030         if remainder_window > X.shape[0]:
-> 2031             pred, pred_max = self._calc_preds(X, choose)
   2032             self.window_cnt += X.shape[0]
   2033             self.window = np.r_[self.window, pred_max]

/databricks/python/lib/python3.7/site-packages/contextualbandits/online.py in _calc_preds(self, X, choose)
   2076
   2077     def _calc_preds(self, X, choose = True):
-> 2078         pred_proba = self._oracles.decision_function(X)
   2079         np.nan_to_num(pred_proba, copy=False)
   2080         pred_max = pred_proba.max(axis = 1)

/databricks/python/lib/python3.7/site-packages/contextualbandits/utils.py in decision_function(self, X)
    927             Parallel(n_jobs=self.njobs, verbose=0, require="sharedmem")\
    928                     (delayed(self._decision_function_single)(choice, X, preds, 1) \
--> 929                         for choice in range(self.n))
    930             _apply_smoothing(preds, self.smooth, self.counters,
    931                              self.noise_to_smooth, self.random_state)

/databricks/python/lib/python3.7/site-packages/joblib/parallel.py in __call__(self, iterable)
   1015
   1016             with self._backend.retrieval_context():
-> 1017                 self.retrieve()
   1018             # Make sure that we get a last message telling us we are done
   1019             elapsed_time = time.time() - self._start_time

/databricks/python/lib/python3.7/site-packages/joblib/parallel.py in retrieve(self)
    907             try:
    908                 if getattr(self._backend, 'supports_timeout', False):
--> 909                     self._output.extend(job.get(timeout=self.timeout))
    910                 else:
    911                     self._output.extend(job.get())

/usr/lib/python3.7/multiprocessing/pool.py in get(self, timeout)
    655             return self._value
    656         else:
--> 657             raise self._value
    658
    659     def _set(self, i, obj):

/usr/lib/python3.7/multiprocessing/pool.py in worker(inqueue, outqueue, initializer, initargs, maxtasks, wrap_exception)
    119         job, i, func, args, kwds = task
    120         try:
--> 121             result = (True, func(*args, **kwds))
    122         except Exception as e:
    123             if wrap_exception and func is not _helper_reraises_exception:

/databricks/python/lib/python3.7/site-packages/joblib/_parallel_backends.py in __call__(self, *args, **kwargs)
    606     def __call__(self, *args, **kwargs):
    607         try:
--> 608             return self.func(*args, **kwargs)
    609         except KeyboardInterrupt:
    610             # We capture the KeyboardInterrupt and reraise it as

/databricks/python/lib/python3.7/site-packages/joblib/parallel.py in __call__(self)
    254         with parallel_backend(self._backend, n_jobs=self._n_jobs):
    255             return [func(*args, **kwargs)
--> 256                     for func, args, kwargs in self.items]
    257
    258     def __len__(self):

/databricks/python/lib/python3.7/site-packages/joblib/parallel.py in (.0)
    254         with parallel_backend(self._backend, n_jobs=self._n_jobs):
    255             return [func(*args, **kwargs)
--> 256                     for func, args, kwargs in self.items]
    257
    258     def __len__(self):

/databricks/python/lib/python3.7/site-packages/contextualbandits/utils.py in _decision_function_single(self, choice, X, preds, depth)
    955             preds[:, choice] = self.algos[choice].decision_function_w_sigmoid(X)
    956         else:
--> 957             preds[:, choice] = self.algos[choice].predict(X)
    958
    959         ### Note to self: it's not a problem to mix different methods from the

/databricks/python/lib/python3.7/site-packages/contextualbandits/linreg/__init__.py in predict(self, X)
    512             The predicted values given 'X'.
    513         """
--> 514         assert self.is_fitted_
    515
    516         pred = X.dot(self.coef_[:self._n])

AssertionError:
```
david-cortes commented 2 years ago

Thanks for the bug report. Should be fixed now.