jsrich closed this issue 7 years ago.
Agreed, that would be nice to have. One thing we already have is the function split_frequencies, which shows how often the forest chose to split on each feature at different depths. But of course it'd be great to have more plotting tools.
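For anyone landing here, a minimal sketch of calling split_frequencies on a causal forest might look like this. The simulated data are purely illustrative (5 features, with a treatment effect driven by the first one). In the grf versions I know of, the result is a matrix with one row per depth (row 1 = the root split of each tree) and one column per feature, but treat that layout as an assumption about your installed version.

```r
library(grf)

# Simulated data purely for illustration: 500 samples, 5 features,
# with a treatment effect that varies with the first feature.
set.seed(1)
n <- 500
p <- 5
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 0.5)
Y <- pmax(X[, 1], 0) * W + X[, 2] + rnorm(n)

forest <- causal_forest(X, Y, W, num.trees = 200)

# Rows are depths (row 1 = root splits), columns are features;
# each entry counts how many splits at that depth used that feature.
freq <- split_frequencies(forest, max.depth = 4)
print(freq)

# Share of root splits made on each feature
round(freq[1, ] / sum(freq[1, ]), 2)
```

In this toy setup you would expect the first column to dominate the top rows, since that is the only feature driving heterogeneity.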
Hi! Related to this question and your answer: the split_frequencies function shows "how often the forest chose to split on each feature at different depths". I understand that by "different depths" you refer to the depth within each tree. Is there any way to infer from this function at which values/levels of each variable the trees tended to split?
Hi @franoteiza, unfortunately there is currently no first-class way to see which feature values the forest tended to split on. Raw splitting information is available through the get_tree function, though: you could load each tree individually, traverse its nodes, and look at the values of split_variable and split_value.
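A minimal sketch of that traversal, which also records the depth of each split, is below. The helper name tree_splits() is hypothetical (not part of grf), and it assumes the node layout that get_tree() returns in the grf versions I'm familiar with: a list whose first element is the root, where internal nodes carry split_variable, split_value, left_child and right_child, and the child fields index into the same list. The root is counted as depth 1 so the depths line up with the rows of split_frequencies. The simulated data are purely illustrative.

```r
library(grf)

# Hypothetical helper (not part of grf): list every split in one tree with its depth.
# Assumes get_tree()$nodes[[1]] is the root and that left_child/right_child
# index into the same nodes list.
tree_splits <- function(forest, index) {
  nodes <- get_tree(forest, index)$nodes
  out <- list()
  stack <- list(c(node = 1, depth = 1))  # root counted as depth 1
  while (length(stack) > 0) {
    # Pop the last (node, depth) pair off the stack
    top <- stack[[length(stack)]]
    stack[[length(stack)]] <- NULL
    node <- nodes[[top[["node"]]]]
    if (!node$is_leaf) {
      out[[length(out) + 1]] <- data.frame(
        tree = index,
        depth = top[["depth"]],
        variable = node$split_variable,
        value = node$split_value
      )
      # Push both children one level deeper
      stack[[length(stack) + 1]] <- c(node = node$left_child, depth = top[["depth"]] + 1)
      stack[[length(stack) + 1]] <- c(node = node$right_child, depth = top[["depth"]] + 1)
    }
  }
  do.call(rbind, out)
}

# Illustrative usage on simulated data
set.seed(1)
n <- 500
X <- matrix(rnorm(n * 5), n, 5)
W <- rbinom(n, 1, 0.5)
Y <- pmax(X[, 1], 0) * W + rnorm(n)
num.trees <- 50
forest <- causal_forest(X, Y, W, num.trees = num.trees)

splits <- tree_splits(forest, 1)
head(splits)

# Root splits across all trees: which variable, and at what value?
all.splits <- do.call(rbind, lapply(seq_len(num.trees), function(i) tree_splits(forest, i)))
roots <- all.splits[all.splits$depth == 1, ]
table(roots$variable)
tapply(roots$value, roots$variable, median)
```

Filtering on depth this way gives the "at which values" view per depth that split_frequencies alone can't provide.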
We agree that this would be very valuable functionality to have -- thank you both for the feedback!
Thanks @jtibshirani for getting back to me! Unfortunately my coding skills are too weak to be able to help out here. Thanks for the package anyway, it's very practical and easy to use.
Following @jtibshirani's suggestion above, below is an R function, split_var_info(), built on data.table, that traverses each tree and collects the variable, split value, and tree number for each split (i.e., the non-leaf nodes). Running it on the first 100 trees of the 4000-tree tau.forest example from this repo shows that it's awesomely slow code :)
The returned data.table doesn't record the split depth, but it gives enough information to calculate summary statistics and plots across trees. Warning: because this function collects split info from every node, and larger samples grow larger trees, it will scale horribly with increasing sample size.
split_var_info <- function(forest, num.trees = NULL) {
  # num.trees selects the first 1 to N trees; if NULL, all trees in the forest are used
  if (is.null(num.trees)) {
    num.trees <- forest$num.trees
  }
  # Create a list to hold the tree-level results
  temp <- vector("list", num.trees)
  for (i in seq_len(num.trees)) {
    nodes <- get_tree(forest, i)$nodes
    # Keep only the internal (non-leaf) nodes: one row per split
    rows <- lapply(nodes, function(j) {
      if (!j$is_leaf) {
        data.frame("variable" = j$split_variable, "value" = j$split_value, "tree" = i)
      } else {
        NULL
      }
    })
    temp[[i]] <- data.table::rbindlist(rows)
  }
  dataOut <- data.table::rbindlist(temp)
  return(dataOut)
}
dataVarInfo <- split_var_info(tau.forest, num.trees = 100)
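As a possible follow-up, here is a sketch of a faster variant plus a per-variable summary of the split values. The name split_table() is hypothetical; it assumes the same get_tree() node fields as above and avoids building one data.frame per node by pulling the fields into plain vectors for each tree. The simulated forest is only for illustration.

```r
library(grf)
library(data.table)

# Hypothetical faster variant: one data.table per tree built from vectors,
# instead of one data.frame per node.
split_table <- function(forest, trees) {
  rbindlist(lapply(trees, function(i) {
    nodes <- get_tree(forest, i)$nodes
    internal <- Filter(function(nd) !nd$is_leaf, nodes)
    data.table(
      variable = vapply(internal, function(nd) nd$split_variable, numeric(1)),
      value = vapply(internal, function(nd) nd$split_value, numeric(1)),
      tree = rep(i, length(internal))
    )
  }))
}

# Illustrative usage on simulated data
set.seed(1)
n <- 500
X <- matrix(rnorm(n * 5), n, 5)
W <- rbinom(n, 1, 0.5)
Y <- pmax(X[, 1], 0) * W + rnorm(n)
forest <- causal_forest(X, Y, W, num.trees = 100)

splits <- split_table(forest, 1:100)

# Per-variable number of splits and quartiles of the chosen split values
summary_by_var <- splits[, .(
  n.splits = .N,
  q25 = quantile(value, 0.25),
  median = median(value),
  q75 = quantile(value, 0.75)
), by = variable][order(-n.splits)]
print(summary_by_var)
```

The same table can go straight into a histogram or density plot of value, faceted by variable.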
Just tried it on my forest and it's exactly what I was looking for, thanks!!
@jsrich Here's a lightweight function that generates something along the lines of the randomForest::varImpPlot output:
var_imp_plot <- function(forest, decay.exponent = 2L, max.depth = 4L) {
  # Calculate variable importance of all features
  # (from print.R): split frequencies at each depth, normalized per depth,
  # then averaged with weights depth^-decay.exponent
  split.freq <- split_frequencies(forest, max.depth)
  split.freq <- split.freq / pmax(1L, rowSums(split.freq))
  weight <- seq_len(nrow(split.freq)) ^ -decay.exponent
  var.importance <- t(split.freq) %*% weight / sum(weight)
  # Format data frame
  library(dplyr)
  # The first p columns of original.data are the covariates; the remaining
  # columns hold the outcome (plus treatment and instrument, where used)
  if (inherits(forest, 'regression_forest') || inherits(forest, 'quantile_forest')) {
    p <- ncol(forest$original.data) - 1L
  } else if (inherits(forest, 'causal_forest')) {
    p <- ncol(forest$original.data) - 2L
  } else if (inherits(forest, 'instrumental_forest')) {
    p <- ncol(forest$original.data) - 3L
  } else {
    stop('Unsupported forest class: ', paste(class(forest), collapse = ', '))
  }
  var.names <- colnames(forest$original.data)[seq_len(p)]
  if (is.null(var.names)) {
    var.names <- paste0('x', seq_len(p))
  }
  # Sort by importance so the most important variable ends up on top after coord_flip()
  df <- tibble(Variable = var.names,
               Importance = as.numeric(var.importance)) %>%
    arrange(Importance) %>%
    mutate(Variable = factor(Variable, levels = unique(Variable)))
  # Plot results
  library(ggplot2)
  plt <- ggplot(df, aes(Variable, Importance)) +
    geom_bar(stat = 'identity') +
    coord_flip() +
    ggtitle('Variable Importance Plot') +
    theme_bw() +
    theme(plot.title = element_text(hjust = 0.5))
  print(plt)
  invisible(plt)
}
Easy to customize that ggplot figure with different colors, backgrounds, etc., but hopefully this should work as a first pass.
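On newer grf releases, a shorter variant should be possible. This sketch assumes two things about the installed version: that grf exports variable_importance() (which, as far as I know, applies the same depth-weighted split-frequency calculation as print.R), and that the covariates are stored in forest$X.orig rather than forest$original.data. The name var_imp_plot2() is hypothetical.

```r
library(grf)
library(ggplot2)

# Hypothetical variant for newer grf versions (assumes variable_importance()
# and forest$X.orig exist in the installed release).
var_imp_plot2 <- function(forest, decay.exponent = 2, max.depth = 4) {
  importance <- as.numeric(variable_importance(forest,
                                               decay.exponent = decay.exponent,
                                               max.depth = max.depth))
  var.names <- colnames(forest$X.orig)
  if (is.null(var.names)) {
    var.names <- paste0("x", seq_along(importance))
  }
  df <- data.frame(Variable = var.names, Importance = importance)
  # Order bars so the most important variable ends up at the top after coord_flip()
  df$Variable <- factor(df$Variable, levels = df$Variable[order(df$Importance)])
  ggplot(df, aes(Variable, Importance)) +
    geom_col() +
    coord_flip() +
    ggtitle("Variable Importance Plot") +
    theme_bw() +
    theme(plot.title = element_text(hjust = 0.5))
}

# Illustrative usage on simulated data with named columns
set.seed(1)
n <- 500
X <- matrix(rnorm(n * 5), n, 5, dimnames = list(NULL, paste0("feat", 1:5)))
W <- rbinom(n, 1, 0.5)
Y <- pmax(X[, 1], 0) * W + rnorm(n)
forest <- causal_forest(X, Y, W, num.trees = 200)
plt <- var_imp_plot2(forest)
print(plt)
```

Returning the ggplot object instead of printing it inside the function makes it easy to keep adding layers or themes afterwards.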
Thanks @nredell and @dswatson for the contributions! These are very nice. For the second one, it'd be interesting to understand better how these plots compare to the permutation-based ones you mentioned in #136.
@swager Agreed! I believe randomForest generates plots for both measures by default. It would be simple to add a similar option to the function above if permutation importance were implemented.
Thank you @dswatson and everyone for your thoughts and input. The var_imp_plot function works great.
Thank you for creating this amazing package in R. I am using a dataset with categorical variables that have a large number of categories. Is there any way to use this package with such variables? When I input the data with these large categorical variables, I get the following error:
Error in regression_train(data$default, data$sparse, outcome.index, as.numeric(tunable.params["mtry"])
Thanks @aliahmedawan for the question. As general guidance, it's best to open an issue if you have a new question, as opposed to commenting on an old one.
As a brief answer, we're planning to add support for categorical variables, and are tracking the issue in https://github.com/swager/grf/issues/109. In the meantime, you might try out the approach suggested in https://github.com/swager/grf/issues/27. However, since you have a large number of categorical variables, it may make sense to simply convert the categories to integers and see how the results look.
Thank you @jtibshirani, I appreciate your reply. We are trying various things, such as using integers for categories and also grouping the categorical variables and taking averages of their outcomes. Thank you very much!
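A minimal base-R sketch of the two encodings mentioned in this exchange: arbitrary integer codes, and integer codes ordered by each category's mean outcome. The variable names and simulated data are hypothetical. For the second option, the level means should ideally be computed on training data only, since the encoding itself uses the outcome.

```r
# Hypothetical example data: one categorical covariate with many levels
set.seed(1)
city <- sample(paste0("city_", 1:50), 1000, replace = TRUE)
y <- rnorm(1000)

# Option 1: arbitrary integer codes (levels in alphabetical order)
city.code <- as.integer(factor(city))

# Option 2: order the levels by their mean outcome before coding,
# so that nearby codes correspond to categories with similar outcomes
level.means <- tapply(y, city, mean)
city.ordered <- as.integer(factor(city, levels = names(sort(level.means))))

# Either column can then be used as a numeric covariate, e.g. cbind(city = city.ordered, ...)
```

With option 2, a single split on the code groups low-outcome categories on one side and high-outcome categories on the other, whereas alphabetical codes carry no such structure.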
I think this package is really interesting. It would be nice if there was a function to do a varImpPlot, similar to what is available in the randomForest package. Right now there is nothing I am aware of that would allow me to identify which variables have the largest impact on the heterogeneity of the effect after fitting a causal forest.