david-cortes / contextualbandits

Python implementations of contextual bandits algorithms
http://contextual-bandits.readthedocs.io
BSD 2-Clause "Simplified" License
750 stars, 148 forks

Problem with explore_rounds in online.py #9

Closed by miko3333 5 years ago

miko3333 commented 5 years ago
def _predict(self, X, exploit = False):
        X = _check_X_input(X)

        if X.shape[0] == 0:
            return np.array([])

        if exploit:
            return self._oracles.predict(X)

        if self.explore_cnt < self.explore_rounds:
            self.explore_cnt += X.shape[0]

            # case 1: all predictions are within allowance
            if self.explore_cnt <= self.explore_rounds:
                return np.random.randint(self.nchoices, size = X.shape[0])

            # case 2: some predictions are within allowance, others are not
            else:
                n_explore = self.explore_rounds - self.explore_cnt
                pred = np.zeros(X.shape[0])
                pred[:n_explore] = np.random.randint(self.nchoices, n_explore)
                pred[n_explore:] = self._oracles.predict(X)
                return pred
        else:
            return self._oracles.predict(X)

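This part of the code in online.py raises a "low >= high" error in case 2. I guess the problem is that n_explore is negative: explore_cnt has already been incremented past explore_rounds by the time explore_rounds - explore_cnt is computed.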
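
For reference, here is a minimal sketch of the split that case 2 appears to intend. It is not the library's actual fix. The standalone function `split_explore_predict` and its `oracle_predict` argument are hypothetical names used only for illustration. Two things differ from the snippet above: the number of exploration rows is computed from the counter before it is advanced, and `size` is passed to `np.random.randint` by keyword, since the second positional argument of that function is `high`, not `size`.

```python
import numpy as np

def split_explore_predict(X, explore_cnt, explore_rounds, nchoices, oracle_predict):
    """Hypothetical standalone version of the case-1/case-2 logic.

    Returns the predictions and the updated exploration counter.
    """
    n_rows = X.shape[0]
    # Rows still within the exploration allowance, computed BEFORE the
    # counter is advanced, and clipped to the range [0, n_rows].
    n_explore = min(max(explore_rounds - explore_cnt, 0), n_rows)

    pred = np.empty(n_rows, dtype=int)
    # `size` must be a keyword argument: randint(low, high=None, size=None)
    pred[:n_explore] = np.random.randint(nchoices, size=n_explore)
    if n_explore < n_rows:
        # Slice, not integer index, so the oracle sees only the remaining rows
        pred[n_explore:] = oracle_predict(X[n_explore:])
    return pred, explore_cnt + n_rows


# Usage with a stand-in oracle that always predicts arm 7 (an assumption,
# chosen so the exploited rows are easy to tell apart from random ones)
X = np.zeros((5, 3))
pred, new_cnt = split_explore_predict(
    X, explore_cnt=8, explore_rounds=10, nchoices=4,
    oracle_predict=lambda rows: np.full(rows.shape[0], 7),
)
```

With 8 of 10 exploration rounds already used, the first 2 rows get random arms and the remaining 3 go to the oracle.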

david-cortes commented 5 years ago

Thanks for the report! Fixed now and pushed to PyPI, please try again and reopen if there's another issue.

david-cortes commented 5 years ago

My bad, I put the wrong sign again in the code, updated yet again.

miko3333 commented 5 years ago

A new problem occurred. pred[n_explore:] = self._oracles.predict(X[n_explore]) should be pred[n_explore:] = self._oracles.predict(X[n_explore:]), otherwise it raises an index error.
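
To illustrate the difference on a plain NumPy array (a small sketch with made-up data, independent of the library): an integer index selects one row and drops the row axis, while a slice keeps every row from that position onward as a 2-D array.

```python
import numpy as np

X = np.arange(12).reshape(6, 2)   # 6 rows, 2 features (made-up data)
n_explore = 2

single_row = X[n_explore]    # integer index: one row, row axis dropped
remaining = X[n_explore:]    # slice: rows 2..5, still 2-D

print(single_row.shape)  # (2,)
print(remaining.shape)   # (4, 2)
```

Only the sliced form has one row per entry of pred[n_explore:], which is what the assignment needs.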

david-cortes commented 5 years ago

Thanks again for spotting the error. Updated now.