JakeColtman / bartpy

Bayesian Additive Regression Trees For Python
https://jakecoltman.github.io/bartpy/
MIT License

TypeError trying to run test code from README.md #40

Open MattWenham opened 5 years ago

MattWenham commented 5 years ago

I'm using the master branch, and getting the following error when trying to run the basic code from README.md on my own data:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<timed exec> in <module>

~\AppData\Local\Continuum\anaconda3\envs\pymc3\lib\site-packages\bartpy\sklearnmodel.py in fit(self, X, y)
    131             self with trained parameter values
    132         """
--> 133         self.model = self._construct_model(X, y)
    134         self.extract = Parallel(n_jobs=self.n_jobs)(self.f_delayed_chains(X, y))
    135         self.combined_chains = self._combine_chains(self.extract)

~\AppData\Local\Continuum\anaconda3\envs\pymc3\lib\site-packages\bartpy\sklearnmodel.py in _construct_model(self, X, y)
    157         if len(X) == 0 or X.shape[1] == 0:
    158             raise ValueError("Empty covariate matrix passed")
--> 159         self.data = self._convert_covariates_to_data(X, y)
    160         self.sigma = Sigma(self.sigma_a, self.sigma_b, self.data.normalizing_scale)
    161         self.model = Model(self.data, self.sigma, n_trees=self.n_trees, alpha=self.alpha, beta=self.beta)

~\AppData\Local\Continuum\anaconda3\envs\pymc3\lib\site-packages\bartpy\sklearnmodel.py in _convert_covariates_to_data(X, y)
    152             X: pd.DataFrame = X
    153             X = X.values
--> 154         return Data(deepcopy(X), deepcopy(y), mask=np.zeros_like(X).astype(bool), normalize=True)
    155 
    156     def _construct_model(self, X: np.ndarray, y: np.ndarray) -> Model:

TypeError: __init__() got an unexpected keyword argument 'mask'

Removing mask=np.zeros_like(X).astype(bool) from line 154 of sklearnmodel.py eliminates the error.
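
The traceback shows only that the installed Data.__init__ does not accept a mask keyword; why it does not (for example, a mismatch between the installed data module and sklearnmodel.py) is a guess. As a rough sketch of how to check this without editing the library by hand, the snippet below uses inspect.signature to test whether a constructor takes a given keyword, and passes mask only when it does. OldData, NewData, accepts_keyword and build_data are hypothetical names for illustration, not bartpy's actual classes or API.

```python
import inspect


def accepts_keyword(func, name):
    """Return True if `func` can be called with keyword argument `name`."""
    params = inspect.signature(func).parameters
    if name in params and params[name].kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    ):
        return True
    # A **kwargs catch-all accepts any keyword as well.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


# Hypothetical stand-ins for two versions of a data container.
# These are NOT bartpy's real classes; they only mimic the signatures
# implied by the traceback (one without `mask`, one with it).
class OldData:
    def __init__(self, X, y, normalize=False):
        self.X, self.y, self.normalize = X, y, normalize


class NewData:
    def __init__(self, X, y, mask=None, normalize=False):
        self.X, self.y, self.mask, self.normalize = X, y, mask, normalize


def build_data(data_cls, X, y):
    """Construct `data_cls`, passing `mask` only if its constructor accepts it."""
    kwargs = {"normalize": True}
    if accepts_keyword(data_cls, "mask"):
        # All-False mask: no entries are treated as missing.
        kwargs["mask"] = [[False] * len(row) for row in X]
    return data_cls(X, y, **kwargs)


X = [[1.0, 2.0], [3.0, 4.0]]
y = [0.5, 1.5]
old = build_data(OldData, X, y)  # constructed without `mask`
new = build_data(NewData, X, y)  # constructed with an all-False mask
```

Running the same check against the installed class, e.g. accepts_keyword(Data, "mask") after importing Data from wherever the installed bartpy defines it, would confirm whether the installed constructor and sklearnmodel.py disagree about the signature.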