grf-labs / grf

Generalized Random Forests
https://grf-labs.github.io/grf/
GNU General Public License v3.0

varImpPlot for causal trees #130

Closed jsrich closed 7 years ago

jsrich commented 7 years ago

I think this package is really interesting. It would be nice if there was a function to do a varImpPlot, similar to what is available in the randomForest package. Right now there is nothing I am aware of that would allow me to identify which variables have the largest impact on the heterogeneity of the effect after fitting a causal forest.

swager commented 7 years ago

Agreed, that would be nice to have. One thing we do already have is the function split_frequencies, which shows how often the forest chose to split on each feature at different depths. But of course it'd be great to have more plotting tools.
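For anyone who wants to try this, here is a minimal sketch of calling split_frequencies on a small simulated causal forest. The simulated data, the object names (X, Y, W, cf, freq), and the choice of num.trees = 200 are assumptions made for illustration, and argument names may differ between grf versions.

```r
library(grf)

# Simulated data (assumption): the treatment effect depends only on the first feature.
set.seed(1)
n <- 500
p <- 5
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 0.5)
Y <- pmax(X[, 1], 0) * W + X[, 2] + rnorm(n)

# Fit a small causal forest and count splits per feature at depths 1 to 4.
cf <- causal_forest(X, Y, W, num.trees = 200)
freq <- split_frequencies(cf, max.depth = 4)
print(freq)
```

Each row of freq corresponds to a depth and each column to a feature, so features with large counts in the first rows were chosen for splits near the root.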

franoteiza commented 7 years ago

Hi! Related to this question and your answer. The split_frequencies function shows "how often the forest chose to split on each feature at different depths". I understand that by "different depths" you refer to the depth within each tree. Is there any way to infer from this function at which values/levels of each variable the trees tended to split?

jtibshirani commented 7 years ago

Hi @franoteiza, unfortunately there is currently no first-class way to see which feature values the forest tended to split on. Raw splitting information is available through the get_tree function though -- you could load each tree individually, traverse through the nodes, and look at the values for split_variable and split_value.

We agree that this would be very valuable functionality to have -- thank you both for the feedback!

franoteiza commented 7 years ago

Thanks @jtibshirani for getting back to me! Unfortunately my coding skills are too weak to be able to help out here. Thanks for the package anyway, it's very practical and easy to use.

nredell commented 7 years ago

Per the direction from @jtibshirani above, below is an R function split_var_info() using data.table that traverses each tree and collects the variable, split value, and tree number for each split (the non-leaves). Running it on the first 1:100 trees from the 4000-tree tau.forest example from this repo shows that it's awesomely slow code :)

The returned data.table object doesn't grab the split depth, but it gives enough info to calculate summary stats and plots across trees. Warning: because this function collects split info from all nodes, it will scale horribly with increasing sample size.

split_var_info <- function(forest, num.trees = NULL) {

  # num.trees selects the first N trees; if NULL (the default), use all trees in the forest
  if (is.null(num.trees)) {
    num.trees <- forest$num.trees
  }

  # Create a list to hold the tree-level results
  temp <- vector("list", num.trees)

  for (i in seq_len(num.trees)) {
    nodes <- get_tree(forest, i)$nodes

    # Keep one row per internal (non-leaf) node; leaves return NULL
    splits <- lapply(nodes, function(node) {
      if (!node$is_leaf) {
        data.frame(variable = node$split_variable, value = node$split_value, tree = i)
      } else {
        NULL
      }
    })
    temp[[i]] <- data.table::rbindlist(splits)
  }

  # Stack the per-tree tables into a single data.table
  data.table::rbindlist(temp)
}

dataVarInfo <- split_var_info(tau.forest, num.trees = 100)
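As a rough sketch of the summary stats and plots mentioned above, the snippet below works on a made-up table with the same three columns (variable, value, tree) that split_var_info() returns. The object split_info and its simulated contents are hypothetical stand-ins so the example runs without a fitted forest; with real output you would use dataVarInfo instead.

```r
# Hypothetical stand-in for the output of split_var_info(): one row per split.
set.seed(42)
split_info <- data.frame(
  variable = sample(1:3, 200, replace = TRUE, prob = c(0.6, 0.3, 0.1)),
  value    = rnorm(200),
  tree     = sample(1:20, 200, replace = TRUE)
)

# How often each variable was split on, across all trees.
split_counts <- table(split_info$variable)
print(split_counts)

# Quartiles of the split values for each variable.
split_quartiles <- tapply(split_info$value, split_info$variable, quantile,
                          probs = c(0.25, 0.5, 0.75))
print(split_quartiles)

# Histogram of split points for the most frequently used variable.
top_var <- as.integer(names(which.max(split_counts)))
hist(split_info$value[split_info$variable == top_var],
     main = paste("Split values for variable", top_var),
     xlab = "Split value")
```

The histogram gives a quick view of where the trees tended to place their cut points for a given variable.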

franoteiza commented 7 years ago

Just tried it on my forest and it's exactly what I was looking for, thanks!!

dswatson commented 7 years ago

@jsrich Here's a lightweight function that generates something along the lines of the randomForest::varImpPlot output:

var_imp_plot <- function(forest, decay.exponent = 2L, max.depth = 4L) {

  # Calculate variable importance of all features
  # (from print.R)
  split.freq <- split_frequencies(forest, max.depth)
  split.freq <- split.freq / pmax(1L, rowSums(split.freq))
  weight <- seq_len(nrow(split.freq)) ^ -decay.exponent
  var.importance <- t(split.freq) %*% weight / sum(weight)

  # Number of feature columns (original.data also stores the non-feature columns)
  if (inherits(forest, c('regression_forest', 'quantile_forest'))) {
    num.features <- ncol(forest$original.data) - 1L
  } else if (inherits(forest, 'causal_forest')) {
    num.features <- ncol(forest$original.data) - 2L
  } else if (inherits(forest, 'instrumental_forest')) {
    num.features <- ncol(forest$original.data) - 3L
  } else {
    stop('Unsupported forest type: ', class(forest)[1])
  }
  var.names <- colnames(forest$original.data)[seq_len(num.features)]
  if (is.null(var.names)) {
    var.names <- paste0('x', seq_len(num.features))
  }

  # Format data frame, sorted so the most important variable is plotted on top
  df <- data.frame(Variable = var.names, Importance = as.numeric(var.importance))
  df <- df[order(df$Importance), ]
  df$Variable <- factor(df$Variable, levels = unique(df$Variable))

  # Plot results
  plt <- ggplot2::ggplot(df, ggplot2::aes(Variable, Importance)) +
    ggplot2::geom_bar(stat = 'identity') +
    ggplot2::coord_flip() +
    ggplot2::ggtitle('Variable Importance Plot') +
    ggplot2::theme_bw() +
    ggplot2::theme(plot.title = ggplot2::element_text(hjust = 0.5))
  print(plt)
  invisible(plt)

}

Easy to customize that ggplot figure with different colors, backgrounds, etc., but hopefully this should work as a first pass.
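To make the weighting in the first part of var_imp_plot concrete, here is a small worked example in base R. The split counts in the toy matrix are made up for illustration; only the arithmetic mirrors the function above.

```r
# Toy split counts (made-up numbers): rows are depths 1 to 3, columns are features.
split.freq <- matrix(c(8, 2, 0,
                       5, 4, 1,
                       3, 3, 4),
                     nrow = 3, byrow = TRUE,
                     dimnames = list(paste0("depth", 1:3), c("x1", "x2", "x3")))

decay.exponent <- 2

# Turn counts into per-depth shares (pmax guards against depths with no splits).
split.share <- split.freq / pmax(1, rowSums(split.freq))

# Depth d gets weight d^(-decay.exponent), here 1, 1/4 and 1/9.
weight <- seq_len(nrow(split.share)) ^ -decay.exponent

# Weighted average of the shares across depths, one value per feature.
var.importance <- drop(t(split.share) %*% weight / sum(weight))
print(round(var.importance, 3))
```

Here x1 comes out at roughly 0.70, x2 at 0.24 and x3 at 0.05. Splits near the root count the most, and the importances sum to 1 whenever every depth has at least one split.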

swager commented 7 years ago

Thanks @nredell and @dswatson for the contributions! These are very nice. For the second one, it'd be interesting to understand better how these plots compare to the permutation-based ones you mentioned in #136.

dswatson commented 7 years ago

@swager Agreed! I believe randomForest generates plots for both measures by default. It would be simple to add a similar option to the function above if permutation importance were implemented.

jsrich commented 7 years ago

Thank you @dswatson and everyone for your thoughts and input. The var_imp_plot function works great.

aliahmedawan commented 6 years ago

Thank you for creating this amazing package in R. I am using a dataset with categorical variables that have a large number of categories. Is there any way to use this package with such variables? When I input the data with these large categorical variables, I get the following error.

Error in regression_train(data$default, data$sparse, outcome.index, as.numeric(tunable.params["mtry"])

jtibshirani commented 6 years ago

Thanks @aliahmedawan for the question. As general guidance, it's best to open an issue if you have a new question, as opposed to commenting on an old one.

As a brief answer, we're planning to add support for categorical variables, and are tracking the issue in https://github.com/swager/grf/issues/109. In the meantime, you might try out the approach suggested in https://github.com/swager/grf/issues/27. However, since you have a large number of categorical variables, it may make sense to simply convert the categories to integers and see how the results look.
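A minimal sketch of the integer-coding idea in base R is below. The data frame df, its columns, and the name city_code are hypothetical. Note that integer codes impose an arbitrary ordering on the categories, so it is worth checking how sensitive the results are to that ordering.

```r
# Hypothetical data: one categorical column and one numeric column.
df <- data.frame(
  city = c("Oslo", "Lima", "Oslo", "Cairo", "Lima"),
  age  = c(34, 51, 27, 45, 38),
  stringsAsFactors = FALSE
)

# Map each category to an integer code, with levels in alphabetical order.
city_levels <- sort(unique(df$city))
df$city_code <- as.integer(factor(df$city, levels = city_levels))

# Numeric feature matrix that could be passed to a forest as its covariates.
X <- as.matrix(df[, c("city_code", "age")])
print(X)
```

Keeping city_levels around makes it possible to apply the same coding to new data at prediction time.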

aliahmedawan commented 6 years ago

Thank you @jtibshirani, I appreciate your reply. We are trying various things, such as using integers for categories and also grouping the categorical variables and taking averages of their outcomes. Thank you very much!