SauceCat / PDPbox

Python partial dependence plot toolbox
http://pdpbox.readthedocs.io/en/latest/
MIT License

pdp_interact_plot dimension reference subplot out of alignment. #55

Closed: dyerrington closed this issue 3 years ago.

dyerrington commented 5 years ago

Here is my code to reproduce the problem:

from pdpbox import pdp
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
import pandas as pd
import matplotlib.pyplot as plt

%matplotlib inline  # Jupyter only

# Setup data: iris features, with the class label stored in the index
data = load_iris()
df = pd.DataFrame(data.data, columns=data.feature_names)
df.index = data.target

# Train basic model
estimator = RandomForestClassifier()
model = estimator.fit(df, df.index)

# pdp_interactions: pdp_interact takes exactly two features,
# with one num_grid_points / percentile_ranges entry per feature
features = ['petal length (cm)', 'petal width (cm)']
pdp_paid = pdp.pdp_interact(
    model=model, dataset=df, model_features=list(df.columns), features=features,
    num_grid_points=[5, 5],
    percentile_ranges=[None, None],
    n_jobs=4
)

# plotting: one interaction grid per class, with reference PDPs (plot_pdp=True)
fig, axes = pdp.pdp_interact_plot(
    pdp_paid, features, plot_type='grid', x_quantile=True, ncols=2, plot_pdp=True,
    which_classes=[0, 1, 2]
)

[Screenshot of the resulting figure]


The problem: in your reference docs, the dimension reference subplots shown above and to the left of each class plot are aligned with that plot's grid. In my output they are squished and out of alignment instead. I can probably figure out how to reference the axes or figure directly and correct them, but is this expected? Is there an easy fix?

Thanks! Great library!
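To reference the axes directly, it helps to see exactly what `pdp_interact_plot` returned. The sketch below is a hypothetical, stdlib-only helper (`iter_leaves` is not part of pdpbox) that walks a nested structure of dicts, lists, and tuples and reports the indexing path to every leaf object. The `example_axes` structure is a stand-in modeled on the keys used in this thread (`pdp_inter_ax`, `_pdp_x_ax`, `_pdp_y_ax`), with strings in place of matplotlib `Axes`; it is an assumption, not pdpbox's documented return value. With a real plot you would pass the returned `axes` instead.

```python
from typing import Any, Iterator, Tuple


def iter_leaves(obj: Any, path: str = "axes") -> Iterator[Tuple[str, Any]]:
    """Yield (indexing path, leaf) for every non-container item in a nested
    structure of dicts, lists, and tuples (hypothetical helper, not pdpbox API).

    NumPy object arrays are treated as leaves; call .tolist() on them first
    if you want their elements walked.
    """
    if isinstance(obj, dict):
        for key, value in obj.items():
            yield from iter_leaves(value, f"{path}[{key!r}]")
    elif isinstance(obj, (list, tuple)):
        for index, value in enumerate(obj):
            yield from iter_leaves(value, f"{path}[{index}]")
    else:
        yield path, obj


# Stand-in for the `axes` returned by pdp_interact_plot (strings instead of Axes).
example_axes = {
    "pdp_inter_ax": [
        {"_pdp_x_ax": "x0", "_pdp_y_ax": "y0"},
        {"_pdp_x_ax": "x1", "_pdp_y_ax": "y1"},
    ]
}
for leaf_path, leaf in iter_leaves(example_axes):
    print(leaf_path, "->", leaf)
```

Each printed path, such as `axes['pdp_inter_ax'][0]['_pdp_x_ax']`, is a valid Python expression that reaches that subplot, so it can be pasted straight into a fix-up snippet.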

dyerrington commented 5 years ago

Since `pdp_interact_plot` returns the axes along with the figure, I've manually adjusted the aspect ratios of the reference axes. It's not a great solution and may not work for everyone, but this snippet updates the aspect for a plot of any size.

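import numpy as np

aspect_x, aspect_y = (.6, 1.6)  # Update to your liking
# For one class plot, set the aspect of its x and y reference axes
update_aspect = lambda ax: (ax['_pdp_x_ax'].set_aspect(aspect_x), ax['_pdp_y_ax'].set_aspect(aspect_y))
# np.vectorize applies update_aspect to every element of axes['pdp_inter_ax']
map_aspect = np.vectorize(update_aspect)
map_aspect(axes['pdp_inter_ax']);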
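`np.vectorize` is used here only for its side effects. A plain loop does the same job, avoids the NumPy dependency, and makes the one-class versus several-classes cases explicit. The sketch below assumes, as the snippet does, that `axes['pdp_inter_ax']` holds one dict per class plot with `'_pdp_x_ax'` and `'_pdp_y_ax'` keys, and further assumes that a single class yields a bare dict rather than a list. `set_marginal_aspect`, `_per_class_axes`, and `FakeAx` are hypothetical names; `FakeAx` only stands in for a matplotlib `Axes` so the example runs without pdpbox.

```python
from typing import Any, List


def _per_class_axes(pdp_inter_ax: Any) -> List[dict]:
    # Assumption: one class -> a single dict; several classes -> a sequence of dicts.
    if isinstance(pdp_inter_ax, dict):
        return [pdp_inter_ax]
    return list(pdp_inter_ax)


def set_marginal_aspect(axes: dict, aspect_x: float = 0.6, aspect_y: float = 1.6) -> int:
    """Set the aspect of the x and y reference axes of every class plot.

    Returns the number of class plots that were updated.
    """
    per_class = _per_class_axes(axes["pdp_inter_ax"])
    for class_axes in per_class:
        class_axes["_pdp_x_ax"].set_aspect(aspect_x)
        class_axes["_pdp_y_ax"].set_aspect(aspect_y)
    return len(per_class)


class FakeAx:
    """Stand-in for matplotlib.axes.Axes that only records set_aspect calls."""

    def __init__(self) -> None:
        self.aspect = None

    def set_aspect(self, aspect: float) -> None:
        self.aspect = aspect


# Three classes, matching which_classes=[0, 1, 2] in the reproduction code.
demo_axes = {"pdp_inter_ax": [{"_pdp_x_ax": FakeAx(), "_pdp_y_ax": FakeAx()} for _ in range(3)]}
print(set_marginal_aspect(demo_axes))  # prints 3
```

With a real figure you would call `set_marginal_aspect(axes)` on the `axes` returned by `pdp.pdp_interact_plot`; matplotlib's `Axes.set_aspect` accepts a float ratio, just as in the snippet above.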

SauceCat commented 3 years ago

It's fixed in the next version.