bgreenwell / pdp

A general framework for constructing partial dependence (i.e., marginal effect) plots from various types of machine learning models in R.
http://bgreenwell.github.io/pdp

Is it possible to add color gradient to 3d plots? #93

Closed DeFilippis closed 5 years ago

DeFilippis commented 5 years ago

I'm trying to add a color gradient to a 3d two-way PDP plot like so:

rwb <- colorRampPalette(c("red", "white", "blue"))
plotPartial(pdp_mtg_duration_attd_interaction, levelplot = FALSE, zlab = "revenue_log_s", 
            colorkey = TRUE, 
            col.regions = rwb)

However, this isn't working. I don't see any examples in your vignettes adding color to the 3d plot -- is this not possible?

bgreenwell commented 5 years ago

Hi @DeFilippis, yes this is certainly possible and you nearly have it. colorRampPalette() returns a function that takes an integer specifying the number of colors to interpolate; so use, for example, rwb(100) instead of rwb. Below is a reproducible example:

# Load required packages
library(pdp)     # for partial dependence plots
library(ranger)  # for fitting random forests

# Simulate training data from the Friedman 1 benchmark problem
set.seed(101)  # for reproducibility
trn <- as.data.frame(mlbench::mlbench.friedman1(500, sd = 0.1))
rfo <- ranger(y ~ ., data = trn)

# Compute the bivariate partial dependence of y on x.1 and x.2
pd <- partial(rfo, pred.var = c("x.1", "x.2"), progress = "text")

# Construct PDP
cols <- colorRampPalette(c("red", "blue", "green"))(100)
plotPartial(pd, levelplot = FALSE, drape = TRUE, col.regions = cols)

[image: the 3-D partial dependence plot produced by the code above]

bgreenwell commented 5 years ago

@DeFilippis note that you should also specify drape = TRUE in the call to plotPartial(), as in the example above.
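
To make the function-versus-vector distinction concrete, here is a minimal base-R sketch. The name rwb is taken from the question above; nothing in it depends on pdp. It shows that colorRampPalette() hands back a function, and that calling it with a count yields a character vector of hex colors like the one passed to col.regions in the example.

```r
# colorRampPalette() returns a palette *function*, not a vector of colors
rwb <- colorRampPalette(c("red", "white", "blue"))
is.function(rwb)

# Calling that function with the number of colors wanted gives a
# character vector of hex codes
cols <- rwb(100)
length(cols)

# The end points of the ramp are the first and last colors supplied
cols[c(1, 100)]
```

The count (100 here) only controls how smooth the gradient looks; any reasonably large value should work.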

DeFilippis commented 5 years ago

Wow, this is gorgeous! Thank you so much. I had never seen anything about drape. What does it do?

bgreenwell commented 5 years ago

No problem! It's useful to know that plotPartial() just calls plotting functions from the underrated (and built-in) lattice package. For 3-D plots (i.e., levelplot = FALSE) it calls lattice::wireframe(). The drape argument to that function is a logical indicating whether or not to "drape" the surface in color; the colors themselves are specified via the col.regions argument. By default, pdp uses the wonderful viridis color palette from matplotlib. You can change the palette via the palette argument (see ?pdp::plotPartial and ?pdp::autoplot.partial for details).
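
As a rough illustration of what drape does on its own, the sketch below calls lattice::wireframe() directly, without pdp. It assumes the built-in volcano elevation matrix as stand-in data; the object names p_plain and p_drape and the axis labels are made up for this example.

```r
library(lattice)  # ships with R as a recommended package

# Custom gradient: 100 colors interpolated from red through white to blue
cols <- colorRampPalette(c("red", "white", "blue"))(100)

# drape = FALSE: the surface is drawn as a plain mesh
p_plain <- wireframe(volcano, drape = FALSE)

# drape = TRUE: the surface is colored by height using col.regions,
# and colorkey = TRUE adds a legend for the gradient
p_drape <- wireframe(volcano, drape = TRUE, colorkey = TRUE,
                     col.regions = cols,
                     xlab = "row", ylab = "column", zlab = "height")

# lattice functions return "trellis" objects; printing one draws it
print(p_drape)
```

Printing p_plain and p_drape one after the other should make the difference easy to see. The same two arguments are what the plotPartial() example earlier in this thread sets.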