g3rv4 / RandomForest

A tiny library to run, in C#, random forests created in R
MIT License

Serialize the random forests #3

Open g3rv4 opened 5 years ago

g3rv4 commented 5 years ago

Serialize them instead of loading the PMML files every time. Protobuf?

NickLarsen commented 5 years ago

The point of using PMML is that you can directly import a file exported from an R model. Whatever other options you consider should probably meet the same requirement. With that in mind, I have a proposal.

There is an alternate representation of a decision tree that is essentially a list of predicate descriptions, where each index is a different node in the tree. The randomForest R package has a function called getTree that returns a single tree as a data frame in this format. With a quick function, you can combine the data frames for all the trees in the forest into a single data frame.

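install.packages("randomForest")  # one-time setup
install.packages("tidyverse")
library(tidyverse)
library(randomForest)

data_train <- read.csv("https://raw.githubusercontent.com/thomaspernet/data_csv_r/master/data/titanic_train.csv") %>%
  select(-1) %>%
  # treat survival as a class label so randomForest grows classification trees
  mutate(survived = as.factor(survived))

set.seed(234579)
# randomForest stops on missing values by default (na.action = na.fail); drop incomplete rows instead
rf_model = randomForest(survived ~ ., data = data_train, ntree = 10, na.action = na.omit)

extract_trees = function(rf) {
  treeNums = seq_len(rf$ntree)
  # one data frame per tree, each row tagged with the tree it belongs to
  get_single_tree = function(treeNum) {
    getTree(rf, k = treeNum, labelVar = TRUE) %>%
      mutate(TreeNum = treeNum)
  }
  lapply(treeNums, get_single_tree) %>% bind_rows()
}

flat_rf = extract_trees(rf_model)
write_csv(flat_rf, "rf-titanic-survived.csv")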

The header looks like:

"left daughter","right daughter","split var","split point","status","prediction","TreeNum"

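The split var field is the name of the variable tested at this node, and split point is the value it is compared against: for a numeric variable, rows whose value is less than or equal to the split point go to the left daughter and the rest go right. (For a categorical variable, getTree encodes the split point as an integer whose binary expansion gives the categories that go to each side.) Leaf (terminal) nodes have a status of -1, a prediction that is not missing, and 0 in both daughter columns. For every other node, left daughter and right daughter are the 1-based row numbers of the child nodes within the same tree (the same TreeNum).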
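To make the traversal rule concrete, here is a minimal Python sketch that walks one tree stored in this row layout. It is illustrative only: the `Node` class, the `predict_tree` function, and the small example tree over `age` and `fare` are hypothetical, and it assumes numeric predictors only, so categorical (bit-set) splits are not handled.

```python
from dataclasses import dataclass
from typing import Mapping, Optional, Sequence


@dataclass
class Node:
    left: int                   # 1-based row of the left daughter, 0 for a leaf
    right: int                  # 1-based row of the right daughter, 0 for a leaf
    split_var: Optional[str]    # variable tested at this node, None for a leaf
    split_point: float          # numeric threshold (assumption: numeric splits only)
    prediction: Optional[str]   # class label at a leaf, None for internal nodes


def predict_tree(nodes: Sequence[Node], row: Mapping[str, float]) -> Optional[str]:
    """Walk one tree from its root (R row 1) down to a leaf."""
    i = 0  # R row 1 is Python index 0
    while True:
        node = nodes[i]
        if node.left == 0:  # terminal node: both daughters are 0
            return node.prediction
        # values less than or equal to the split point go left
        if row[node.split_var] <= node.split_point:
            i = node.left - 1
        else:
            i = node.right - 1


# Hypothetical tree: age <= 9.5 -> "1"; otherwise fare <= 50 -> "0", else "1"
nodes = [
    Node(2, 3, "age", 9.5, None),
    Node(0, 0, None, 0.0, "1"),
    Node(4, 5, "fare", 50.0, None),
    Node(0, 0, None, 0.0, "0"),
    Node(0, 0, None, 0.0, "1"),
]
```

With this tree, `predict_tree(nodes, {"age": 30, "fare": 10})` goes right at the root, then left at the `fare` node, and returns `"0"`.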

This data frame can be stored efficiently in just about any format (CSV, Feather, protobuf, whatever); all you have to do is swap out the line that writes the file.
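Whatever the format, the reader only has to group the rows back into trees without reordering them, because the daughter columns are row numbers within each tree. A minimal Python sketch of reading the CSV written by `write_csv` (which writes missing values as `NA`) follows; the `load_forest` name, the tuple layout, and the tiny sample file are hypothetical.

```python
import csv
import io
from collections import defaultdict
from typing import Dict, List, Optional, TextIO, Tuple

# (left daughter, right daughter, split var, split point, prediction); hypothetical layout
NodeRow = Tuple[int, int, Optional[str], float, Optional[str]]


def _na(value: str) -> Optional[str]:
    """Treat R's NA (or an empty field) as missing."""
    return None if value in ("NA", "") else value


def load_forest(stream: TextIO) -> Dict[int, List[NodeRow]]:
    """Group flat getTree rows by TreeNum, keeping file order within each tree."""
    trees: Dict[int, List[NodeRow]] = defaultdict(list)
    for rec in csv.DictReader(stream):
        split_point = _na(rec["split point"])
        trees[int(float(rec["TreeNum"]))].append((
            int(float(rec["left daughter"])),
            int(float(rec["right daughter"])),
            _na(rec["split var"]),
            float(split_point) if split_point is not None else 0.0,
            _na(rec["prediction"]),
        ))
    return dict(trees)


# Hypothetical two-tree sample in the header format shown above
SAMPLE = '''"left daughter","right daughter","split var","split point","status","prediction","TreeNum"
2,3,age,9.5,1,NA,1
0,0,NA,0,-1,1,1
0,0,NA,0,-1,0,1
0,0,NA,0,-1,1,2
'''
forest = load_forest(io.StringIO(SAMPLE))
```

Swapping CSV for another format would only change how the records are read; the grouping by TreeNum stays the same.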


I should note that this storage method has benefits beyond loading the model quickly. If you rewrite the tree implementation to store this list of predicates directly, the code will be more direct and easier to read, and you will likely get a notable performance boost from far more efficient memory usage, since each tree can live in a few flat arrays instead of many separately allocated node objects.
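One way to store the predicates directly is a struct-of-arrays layout: every node of every tree goes into parallel contiguous arrays, and getTree's per-tree row numbers are rebased to absolute offsets as each tree is added. The Python sketch below uses the standard `array` module; `FlatForest`, its field names, and the integer encoding of split variables and class labels are hypothetical. Only numeric splits are supported, and ties in the majority vote go to the label counted first, which may not match how R breaks ties.

```python
from array import array
from collections import Counter
from typing import Iterable, List, Sequence, Tuple

# (left, right, feature index, threshold, label) in getTree row order; hypothetical encoding
Row = Tuple[int, int, int, float, int]


class FlatForest:
    """Hypothetical struct-of-arrays forest: all nodes of all trees share
    parallel contiguous arrays, and roots[t] is the index of tree t's root."""

    def __init__(self) -> None:
        self.left = array("i")       # absolute index of the left child, -1 for a leaf
        self.right = array("i")      # absolute index of the right child, -1 for a leaf
        self.feature = array("i")    # column index of the split variable
        self.threshold = array("d")  # numeric split: go left when value <= threshold
        self.label = array("i")      # class index stored at leaves
        self.roots: List[int] = []

    def add_tree(self, rows: Iterable[Row]) -> None:
        """Append one tree; 1-based daughters (0 = leaf) become absolute offsets."""
        base = len(self.left)
        self.roots.append(base)
        for left, right, feature, threshold, label in rows:
            self.left.append(base + left - 1 if left else -1)
            self.right.append(base + right - 1 if right else -1)
            self.feature.append(feature)
            self.threshold.append(threshold)
            self.label.append(label)

    def predict(self, x: Sequence[float]) -> int:
        """Majority vote over all trees; ties go to the label counted first."""
        votes: Counter = Counter()
        for i in self.roots:
            while self.left[i] != -1:  # descend until a leaf
                go_left = x[self.feature[i]] <= self.threshold[i]
                i = self.left[i] if go_left else self.right[i]
            votes[self.label[i]] += 1
        return votes.most_common(1)[0][0]


# Hypothetical three-tree forest over features [age, fare] with labels 0/1
forest = FlatForest()
forest.add_tree([(2, 3, 0, 9.5, -1), (0, 0, -1, 0.0, 1), (0, 0, -1, 0.0, 0)])
forest.add_tree([(2, 3, 1, 50.0, -1), (0, 0, -1, 0.0, 0), (0, 0, -1, 0.0, 1)])
forest.add_tree([(0, 0, -1, 0.0, 1)])
```

Here `forest.predict([5, 10])` collects votes 1, 0 and 1 from the three trees and returns `1`. Because traversal only indexes into a handful of arrays, the same layout ports naturally to plain arrays in C#.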

g3rv4 commented 5 years ago

👍 That's beautiful! I was thinking I'd just protobuf the trees after initially loading the PMML file, but this is way nicer... also, just removing the need for Java makes it extremely worth it!