bgreenwell / pdp

A general framework for constructing partial dependence (i.e., marginal effect) plots from various types of machine learning models in R.
http://bgreenwell.github.io/pdp

Fast marginal effect plots (i.e., "poor man's PDPs") #91

Closed: bgreenwell closed this issue 5 years ago

bgreenwell commented 5 years ago

Rather than averaging predictions over the entire training set, you can fix all other features at a single typical value: the median for numeric features and the most frequent value for categorical features. (If there are no interaction effects, these plots will be parallel to the corresponding PDPs.) This is similar in spirit to the excellent plotmo package:

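# Return a single "typical" row of a data set: the median of each numeric
# column and the most frequent value of every other column
exemplar <- function(object) {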
  UseMethod("exemplar")
}

exemplar.data.frame <- function(object) {
  res <- as.data.frame(lapply(object, FUN = function(x) {
    if (is.numeric(x)) {
      stats::median(x, na.rm = TRUE)
    } else {
      names(which.max(table(x, useNA = "no")))
    }
  }))
  # Trick to copy the column classes (e.g., factor levels) of the original
  # data: bind the summary row below the first observation, then drop that row
  res <- rbind(object[1L, ], res)
  res[-1L, ]
}

#
# Example
#

# Load required packages
library(ggplot2)  # for ggtitle() function
library(pdp)      # for visualizing feature effects

# Fit a random forest to the Ames housing data
set.seed(101)
ames <- AmesHousing::make_ames()
rfo <- ranger::ranger(Sale_Price ~ ., data = ames)

# Marginal-effect plot
system.time(
  p1 <- partial(
    object = rfo, 
    pred.var = "Gr_Liv_Area", 
    train = exemplar(ames), 
    pred.grid = data.frame(
      "Gr_Liv_Area" = seq(from = min(ames$Gr_Liv_Area), to = max(ames$Gr_Liv_Area), length.out = 100)
    ),
    plot = TRUE,
    plot.engine = "ggplot2"
  ) + ggtitle("Marginal effect plot")
)
#   user  system elapsed 
#  1.972   1.383   3.101 

# Partial dependence plot
system.time(
  p2 <- partial(
    object = rfo, 
    pred.var = "Gr_Liv_Area",
    train = ames,
    grid.resolution = 100,
    plot = TRUE,
    plot.engine = "ggplot2"
  ) + ggtitle("Partial dependence plot")
)
#   user  system elapsed 
# 44.232   2.273   9.572 

# Display plots side by side
gridExtra::grid.arrange(p1, p2, nrow = 1)

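[Image: the marginal effect plot (left) and the partial dependence plot (right) for Gr_Liv_Area, side by side]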
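
The claim in the opening comment, that without interaction effects the marginal effect plot runs parallel to the PDP, can be checked numerically. The sketch below is a minimal base-R illustration under stated assumptions: `f_add` and `f_int` are made-up toy prediction functions (not fitted models), and `pd_curve()` and `me_curve()` are hypothetical from-scratch helpers, not pdp functions. `me_curve()` uses medians only because every feature here is numeric.

```r
# Hypothetical toy check (base R only): compare partial dependence and
# median-based marginal effect curves for x1 under two made-up models
set.seed(42)
train <- data.frame(x1 = runif(500), x2 = rnorm(500))

f_add <- function(d) 2 * d$x1 + d$x2^2  # additive: no x1-x2 interaction
f_int <- function(d) d$x1 * d$x2^2      # x1 interacts with x2

# Partial dependence: average predictions over a copy of the full training
# data with `var` set to each grid value in turn
pd_curve <- function(f, train, var, grid) {
  vapply(grid, function(v) {
    tmp <- train
    tmp[[var]] <- v
    mean(f(tmp))
  }, numeric(1))
}

# Marginal effect: score one row per grid value, with the other (numeric)
# features fixed at their medians
me_curve <- function(f, train, var, grid) {
  ref <- as.data.frame(lapply(train, stats::median))
  ref <- ref[rep(1L, length(grid)), , drop = FALSE]
  ref[[var]] <- grid
  f(ref)
}

grid <- seq(0, 1, length.out = 25)
gap_add <- pd_curve(f_add, train, "x1", grid) - me_curve(f_add, train, "x1", grid)
gap_int <- pd_curve(f_int, train, "x1", grid) - me_curve(f_int, train, "x1", grid)

diff(range(gap_add))  # essentially zero: the two curves are parallel
diff(range(gap_int))  # clearly non-zero: the curves are not parallel
```

Because `f_add` splits into a term in x1 plus a term in x2, the two curves differ only by the constant `mean(x2^2) - median(x2)^2`; the `x1 * x2^2` term in `f_int` makes that gap grow with x1.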

DeFilippis commented 5 years ago

Thank you so much for the fantastic work on this package. It's been an absolute lifesaver. I had a couple of questions for you about this:

  1. I notice you used system.time(). Is there any chance you could show the time savings of a margins plot over a partial dependence plot?

  2. I'm a little confused about the difference between a PDP and a margins plot. It looks like your custom squash function returns a data frame of the same size as the input data frame, except that all numeric columns are replaced by their median and all factor columns by their mode. It then runs the usual partial() procedure on this data set.

Is the difference between this and partial() that, in the partial() case, all the covariates keep their real values (instead of being fixed at their median), so you're getting predictions averaged over every real value of the other variables rather than evaluated at their median values?

If so, where does the speed-up come from?

  3. Will you be implementing this in the package? Perhaps with a "marginsPlot = TRUE" argument?

bgreenwell commented 5 years ago

Thanks @DeFilippis, glad you've found the package useful. And I've been meaning to come back to this. Responses to your questions below:

  1. Times added. Note, however, that I expect the time savings to be much more dramatic for larger data sets or when computing bivariate plots.

  2. You have the difference right: marginal effect plots look at one variable (or more) vs. the response while holding all other features constant (e.g., at their median), whereas PDPs look at one variable (or more) vs. the response while accounting for the average effect of all the other features. In particular, each point on a PDP is computed as the average prediction obtained from a modified copy of the original training data in which the feature(s) of interest are set to that point's value. In other words, it requires scoring lots of data (albeit independently) and many calls to the prediction function. Marginal effect plots, on the other hand, require scoring only one observation per point on the plot, so they are much quicker and more efficient (but less accurate than PDPs, especially when strong interactions are present).
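
To make the speed-up concrete, the sketch below counts how many rows each approach sends to the prediction function. It is a from-scratch base-R illustration, not pdp's internals: the toy model `f`, the counting wrapper `make_counter()`, the data, and the grid size are all hypothetical choices.

```r
# Hypothetical sketch: count the rows each approach sends to the
# prediction function, using a made-up toy model
make_counter <- function(f) {
  n_rows <- 0
  list(
    predict = function(newdata) {
      n_rows <<- n_rows + nrow(newdata)  # update the counter in the enclosing env
      f(newdata)
    },
    rows = function() n_rows
  )
}

set.seed(1)
n <- 1000
train <- data.frame(x1 = runif(n), x2 = runif(n), x3 = runif(n))
f <- function(d) d$x1 + d$x2 * d$x3
grid <- seq(0, 1, length.out = 50)

# PDP for x1: one modified copy of the full training set per grid value
pd <- make_counter(f)
pd_vals <- vapply(grid, function(v) {
  tmp <- train
  tmp$x1 <- v
  mean(pd$predict(tmp))
}, numeric(1))
pd$rows()  # 50 * 1000 = 50000 rows scored

# Marginal effect for x1: a single "typical" row per grid value
me <- make_counter(f)
ref <- as.data.frame(lapply(train, stats::median))[rep(1L, length(grid)), ]
ref$x1 <- grid
me_vals <- me$predict(ref)
me$rows()  # 50 rows scored

# A bivariate (x1, x2) plot on a 50 x 50 grid multiplies the PDP cost
biv_grid <- expand.grid(x1 = grid, x2 = grid)
biv_pd_rows <- nrow(biv_grid) * n  # 2500 * 1000 rows for the PDP
biv_me_rows <- nrow(biv_grid)      # 2500 rows for the marginal effect plot
```

The counts, not the particular model, are the point: for n training rows and k grid points a PDP scores k × n rows while a marginal effect plot scores k, and a bivariate grid squares k, which is why the savings grow with larger data sets and two-variable plots.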

The code above is nothing more than a way to trick the partial() function into computing a marginal effect plot. An alternative to pdp, as well as another good reference on the difference between the two types of plots, is the excellent plotmo package, which refers to marginal effect plots as "poor man's partial dependence plots".

DeFilippis commented 5 years ago

Perfect! Thanks so much for the thorough and speedy replies. Really appreciate it.