softwaredeng / inTrees

inTrees package

How to extract rules from a caret-trained random forest? #16

Status: Open. Lmijnen opened this issue 4 years ago.

Lmijnen commented 4 years ago

Hello,

I would like to extract rules from a model trained with caret's train function, using method = "ranger".

However, the resulting object is not accepted by the RF2List or Ranger2List functions:

The code:

library(caret)
library(inTrees)

data(iris)
X <- within(iris, rm("Species")); Y <- iris[, "Species"]

rf_train <- train(Species ~ ., data = iris, method = "ranger")

tree_list <- RF2List(rf_train)
# Error in vector("list", rf$ntree) : invalid 'length' argument

tree_list <- Ranger2List(rf_train)
# Error in vector("list", rf_ranger$num.trees) : invalid 'length' argument

I presume some glue code is needed before rules can be extracted from my caret train object (an S3 object of class "train"). Can anyone help me with this?
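The error messages point at the cause. The converters read rf$ntree or rf_ranger$num.trees directly, and a caret train object has no such element at its top level, so `$` returns NULL and vector("list", NULL) fails. The sketch below reproduces this without caret, using a hand-built mock object (mock_train is hypothetical, not a real caret fit, and 500 is an arbitrary tree count).

```r
# Hypothetical mock of a caret "train" object: the forest sits in $finalModel
mock_train <- structure(
  list(finalModel = structure(list(num.trees = 500), class = "ranger")),
  class = "train"
)

# The wrapper itself has no num.trees element, so `$` returns NULL
print(is.null(mock_train$num.trees))            # [1] TRUE

# Using NULL as the length is the call that fails inside the converters
msg <- tryCatch(vector("list", mock_train$num.trees),
                error = function(e) conditionMessage(e))
print(msg)                                      # [1] "invalid 'length' argument"

# The nested model does carry the tree count the converters need
print(mock_train$finalModel$num.trees)          # [1] 500
```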

LucindaLanoy commented 2 months ago

Hello, I know this is an old issue and I'm not involved with the inTrees package, so I hope it's OK for me to answer. I'm posting in case anyone runs into a similar problem and needs help with it.

Ranger2List needs a ranger object specifically. The caret object created by train stores the fitted ranger model in its finalModel element. The same applies if you use method = "rf" instead of "ranger" in train: the randomForest object is in finalModel, and that is what RF2List needs.
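For the method = "rf" route, a minimal end-to-end sketch could look like this. It assumes the caret, randomForest and inTrees packages are installed; the seed, ntree = 100 and the choice of predictor columns for X are illustrative, not taken from the original posts. RF2List, extractRules, getRuleMetric and presentRules are inTrees functions.

```r
library(caret)
library(randomForest)
library(inTrees)

data(iris)
# Predictors in the same order as in the training formula, plus the outcome
X <- iris[, c("Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width")]
target <- iris$Species

# Fit through caret; ntree is passed on to randomForest
set.seed(1)
rf_train <- train(Species ~ ., data = iris, method = "rf", ntree = 100)

# Convert the randomForest object in $finalModel, not the caret wrapper
tree_list <- RF2List(rf_train$finalModel)

# Extract rule conditions, score them on the data, and show them with column names
rule_exec <- extractRules(tree_list, X, ntree = 100)
rule_metric <- getRuleMetric(rule_exec, X, target)
head(presentRules(rule_metric, colnames(X)))
```

inTrees writes conditions by column position (for example X[,3]), so X should contain the same predictor columns, in the same order, as the data the forest was trained on.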

So if you write tree_list <- Ranger2List(rf_train$finalModel) (or, equivalently, tree_list <- Ranger2List(rf_train[["finalModel"]])), it works, at least when I run this code, because you are passing the ranger model that Ranger2List needs rather than the whole caret object.
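If you switch between method = "ranger" and method = "rf", a small wrapper can do the unwrapping and pick the matching converter. This is a sketch only: forest2List is a hypothetical name, not part of inTrees or caret, and it relies on the Ranger2List function quoted in this comment being defined.

```r
# Hypothetical helper (not part of inTrees or caret): unwrap a caret "train"
# object if needed, then pick the converter that matches the forest's class.
forest2List <- function(fit) {
  # caret stores the fitted forest in $finalModel
  model <- if (inherits(fit, "train")) fit$finalModel else fit
  if (inherits(model, "ranger")) {
    Ranger2List(model)        # converter quoted in this comment
  } else if (inherits(model, "randomForest")) {
    inTrees::RF2List(model)
  } else {
    stop("unsupported model class: ", paste(class(model), collapse = ", "))
  }
}
```

With this in place, tree_list <- forest2List(rf_train) works for either training method, and anything else fails with an explicit message instead of the invalid 'length' error.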

In case it helps: I used the Ranger2List function posted in issue #15, with the modification proposed in issue #3, since I couldn't find it anywhere else. Here it is:

Ranger2List <- function(rf_ranger)
{
  # Convert one ranger tree (as returned by ranger::treeInfo) into the
  # randomForest getTree()-style layout that inTrees expects
  formatRanger <- function(tree){
    rownames(tree) <- seq_len(nrow(tree))
    # inTrees marks terminal nodes with status -1 and internal nodes with 1
    tree$status <- ifelse(tree$terminal, -1, 1)
    # ranger numbers nodes from 0; shift child IDs by 1 to get R row indices
    tree$`left daughter` <- tree$leftChild + 1
    tree$`right daughter` <- tree$rightChild + 1
    # shift the split-variable ID by 1 as well (assumes splitvarID is 0-based over the predictors)
    tree$`split var` <- tree$splitvarID + 1
    tree$`split point` <- tree$splitval
    tree <- tree[, c("left daughter", "right daughter", "split var", "split point", "status")]
    return(as.data.frame(tree))
  }
  treeList <- list()
  treeList$ntree <- rf_ranger$num.trees
  treeList$list <- vector("list", rf_ranger$num.trees)
  for(i in seq_len(rf_ranger$num.trees)){
    treeList$list[[i]] <- formatRanger(ranger::treeInfo(rf_ranger, tree = i))
  }
  return(treeList)
}

I hope this is clear. I'm still fairly new to ML and GitHub, so I apologize if I haven't used the correct terms.