SauceCat / PDPbox

Python partial dependence plot toolbox
http://pdpbox.readthedocs.io/en/latest/
MIT License
840 stars · 129 forks · source link

x_quantile fails with binary features #38

Closed — zeromh closed this issue 3 years ago

zeromh commented 5 years ago

When I run

obj = pdp.pdp_isolate(model, X_train, X_train.columns, 'addy_change')
pdp.pdp_plot(obj, 'addy_change', plot_pts_dist=True, x_quantile=True)

where "addy_change" is a binary variable, I get the error pasted below.

The problem seems to be that the `count_data` DataFrame has no `'xticklabels'` column for binary features, but when `x_quantile=True`, `_pdp_plot` reads that column anyway (`count_data['xticklabels'].values`), which raises the `KeyError`.

My use case is that I'm looping through a large list of variables with `x_quantile=True` for all of them. Would it make sense for `pdp_plot` to ignore `x_quantile=True` when the feature is binary?

Barring that, it would be helpful to have a more informative error message.

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
~/.conda/envs/checking/lib/python3.6/site-packages/pandas/core/indexes/base.py in get_loc(self, key, method, tolerance)
   2524             try:
-> 2525                 return self._engine.get_loc(key)
   2526             except KeyError:

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

KeyError: 'xticklabels'

During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)
<ipython-input-70-faa0732e94a6> in <module>()
----> 1 pdp.pdp_plot(obj, 'addy_change', plot_pts_dist=True, x_quantile=True)

~/.conda/envs/checking/lib/python3.6/site-packages/pdpbox/pdp.py in pdp_plot(pdp_isolate_out, feature_name, center, plot_pts_dist, plot_lines, frac_to_plot, cluster, n_cluster_centers, cluster_method, x_quantile, show_percentile, figsize, ncols, plot_params, which_classes)
    414 
    415             _pdp_plot(pdp_isolate_out=pdp_plot_data[0], feature_name=feature_name_adj, pdp_ax=_pdp_ax,
--> 416                       count_ax=_count_ax, **pdp_plot_params)
    417         else:
    418             pdp_ax = plt.subplot(outer_grid[1])

~/.conda/envs/checking/lib/python3.6/site-packages/pdpbox/pdp_plot_utils.py in _pdp_plot(pdp_isolate_out, feature_name, center, plot_lines, frac_to_plot, cluster, n_cluster_centers, cluster_method, x_quantile, show_percentile, pdp_ax, count_data, count_ax, plot_params)
     97             # need to plot data distribution
     98             if x_quantile:
---> 99                 count_display_columns = count_data['xticklabels'].values
    100                 # number of grids = number of bins + 1
    101                 # count_x: min -> max + 1

~/.conda/envs/checking/lib/python3.6/site-packages/pandas/core/frame.py in __getitem__(self, key)
   2137             return self._getitem_multilevel(key)
   2138         else:
-> 2139             return self._getitem_column(key)
   2140 
   2141     def _getitem_column(self, key):

~/.conda/envs/checking/lib/python3.6/site-packages/pandas/core/frame.py in _getitem_column(self, key)
   2144         # get column
   2145         if self.columns.is_unique:
-> 2146             return self._get_item_cache(key)
   2147 
   2148         # duplicate columns & possible reduce dimensionality

~/.conda/envs/checking/lib/python3.6/site-packages/pandas/core/generic.py in _get_item_cache(self, item)
   1840         res = cache.get(item)
   1841         if res is None:
-> 1842             values = self._data.get(item)
   1843             res = self._box_item_values(item, values)
   1844             cache[item] = res

~/.conda/envs/checking/lib/python3.6/site-packages/pandas/core/internals.py in get(self, item, fastpath)
   3841 
   3842             if not isna(item):
-> 3843                 loc = self.items.get_loc(item)
   3844             else:
   3845                 indexer = np.arange(len(self.items))[isna(self.items)]

~/.conda/envs/checking/lib/python3.6/site-packages/pandas/core/indexes/base.py in get_loc(self, key, method, tolerance)
   2525                 return self._engine.get_loc(key)
   2526             except KeyError:
-> 2527                 return self._engine.get_loc(self._maybe_cast_indexer(key))
   2528 
   2529         indexer = self.get_indexer([key], method=method, tolerance=tolerance)

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

KeyError: 'xticklabels'