thomaswiemann / ddml

ddml: Double/Debiased Machine Learning in R
https://thomaswiemann.com/ddml/
GNU General Public License v3.0

How to add another learner #53

Open: zjwait0927 opened this issue 10 months ago

zjwait0927 commented 10 months ago

If I want to use a randomForest learner, how can I add it to the list passed to the learners argument of ddml_ate? Thank you!

thomaswiemann commented 10 months ago

For random forests, I would generally recommend you use the included mdl_ranger function (see here).
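For example, here is a minimal sketch of ddml_ate with mdl_ranger on simulated data. It assumes that ddml and ranger are installed, that mdl_ranger forwards extra arguments such as num.trees to ranger::ranger, and that a single learner is specified in ddml's list(what = ..., args = ...) form. The data-generating process is made up for illustration.

```r
# Sketch: ddml_ate with ddml's built-in ranger wrapper.
# Assumes ddml and ranger are installed; simulated data for illustration only.
library(ddml)

set.seed(42)
n <- 500
X <- matrix(rnorm(n * 5), nrow = n, ncol = 5)
D <- as.numeric(X[, 1] + rnorm(n) > 0)      # binary treatment
y <- 1 * D + X[, 1] + X[, 2]^2 + rnorm(n)   # simulated ATE of 1

# Single learner: `what` is the learner function, `args` its extra arguments
fit_ranger <- ddml_ate(y, D, X,
                       learners = list(what = mdl_ranger,
                                       args = list(num.trees = 200)),
                       sample_folds = 5, silent = TRUE)
summary(fit_ranger)
```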

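Otherwise, this short tutorial illustrates how to construct a new S3 class that is compatible with ddml. In short, ddml needs two things: a fitting function that takes the named arguments y and X, and a predict method that accepts newdata and returns numeric predictions. The code snippet below applies the tutorial to randomForest: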

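```r
# Wrapper function for randomForest::randomForest. ddml calls it with the
# named arguments y (outcome vector) and X (covariate matrix).
mdl_randomForest <- function(y, X,
                             ntree = 100, nodesize = 1, maxnodes = NULL,
                             colsample_bytree = 0.6, subsample = 0.7,
                             replace = FALSE, ...){
  # randomForest needs a dense matrix; convert e.g. sparse Matrix objects
  if(!("matrix" %in% class(X))) X <- Matrix::as.matrix(X)
  # Fit the forest: mtry is the number of features tried at each split,
  # sampsize the number of observations drawn for each tree
  mdl_fit <- randomForest::randomForest(X, y,
                                        ntree = ntree, nodesize = nodesize,
                                        maxnodes = maxnodes,
                                        mtry = ceiling(colsample_bytree *
                                                         ncol(X)),
                                        sampsize = ceiling(subsample *
                                                             length(y)),
                                        replace = replace, ...)
  # Prepend a custom S3 class so that predict() dispatches to the
  # method defined below
  class(mdl_fit) <- c("mdl_randomForest", class(mdl_fit))
  return(mdl_fit)
}#MDL_RANDOMFOREST

# Predict method for mdl_randomForest fits
predict.mdl_randomForest <- function(object, newdata = NULL, ...){
  # Drop the custom class so that stats::predict dispatches to
  # randomForest's own predict method
  class(object) <- class(object)[-1]
  # Convert newdata to a dense matrix, mirroring the fitting function
  if(!is.null(newdata) && !("matrix" %in% class(newdata))){
    newdata <- Matrix::as.matrix(newdata)
  }
  # Return a plain numeric vector of predictions
  as.numeric(stats::predict(object, newdata, ...))
}#PREDICT.MDL_RANDOMFOREST
```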
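A quick sanity check can be run before wiring the wrapper into ddml. This sketch assumes that the two definitions above have been run and that the randomForest package is installed. It fits the wrapper on simulated data (the data-generating process is made up for illustration) and confirms that the custom class is set and that predict returns a plain numeric vector.

```r
# Sanity check for mdl_randomForest on simulated data.
# Assumes mdl_randomForest and predict.mdl_randomForest (above) are defined
# and the randomForest package is installed.
set.seed(1)
n <- 200
X <- matrix(rnorm(n * 3), nrow = n, ncol = 3)
y <- X[, 1] + 0.5 * X[, 2]^2 + rnorm(n)

# Fit with a small forest to keep the check fast
fit <- mdl_randomForest(y, X, ntree = 50)
print(class(fit))  # custom class first, then randomForest's own class

# Predict on the first 10 rows through the custom predict method
y_hat <- predict(fit, newdata = X[1:10, ])
str(y_hat)
```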

You can then pass mdl_randomForest to the learners argument as you would with any other learner.
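A minimal sketch of both ways to do this in ddml_ate, on simulated data. It assumes that the mdl_randomForest definitions above have been run, that ddml and randomForest are installed, and that ddml's documented learner formats apply: list(what = ..., args = ...) for a single learner, and a list of list(fun = ..., args = ...) entries for stacking, here combined with ddml's ols learner.

```r
# Sketch: mdl_randomForest inside ddml_ate, alone and stacked with ols.
# Assumes mdl_randomForest / predict.mdl_randomForest (above) are defined,
# and that ddml and randomForest are installed. Simulated data only.
library(ddml)

set.seed(42)
n <- 500
X <- matrix(rnorm(n * 5), nrow = n, ncol = 5)
D <- as.numeric(X[, 1] + rnorm(n) > 0)      # binary treatment
y <- 1 * D + X[, 1] + X[, 2]^2 + rnorm(n)   # simulated ATE of 1

# Single learner: `what` is the learner function, `args` its extra arguments
fit_rf <- ddml_ate(y, D, X,
                   learners = list(what = mdl_randomForest,
                                   args = list(ntree = 200)),
                   sample_folds = 5, silent = TRUE)

# Multiple learners (stacking): a list of lists with `fun` and `args`
fit_stack <- ddml_ate(y, D, X,
                      learners = list(list(fun = ols),
                                      list(fun = mdl_randomForest,
                                           args = list(ntree = 200))),
                      ensemble_type = "nnls",
                      sample_folds = 5, cv_folds = 5, silent = TRUE)

summary(fit_rf)
summary(fit_stack)
```

Because D is binary but mdl_randomForest fits a regression forest, randomForest may warn that the response has five or fewer unique values. In this sketch the regression fit is intended as the propensity-score estimate, so the warning is expected.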