Open zjwait0927 opened 10 months ago
For random forests, I would generally recommend using the included mdl_ranger function (see here).
Otherwise, this short tutorial illustrates how you can construct a new S3 class that is compatible with ddml. The code snippet below applies the tutorial to randomForest:
# Wrapper function for randomForest
mdl_randomForest <- function(y, X,
                             ntree = 100, nodesize = 1, maxnodes = NULL,
                             colsample_bytree = 0.6, subsample = 0.7,
                             replace = FALSE, ...){
  # Coerce X to a dense base matrix (e.g., if a sparse matrix is passed)
  if(!("matrix" %in% class(X))) X <- Matrix::as.matrix(X)
  # Compute randomForest
  mdl_fit <- randomForest::randomForest(X, y,
                                        ntree = ntree, nodesize = nodesize,
                                        maxnodes = maxnodes,
                                        mtry = ceiling(colsample_bytree *
                                                         ncol(X)),
                                        sampsize = ceiling(subsample *
                                                             length(y)),
                                        replace = replace, ...)
  # Set custom S3 class
  class(mdl_fit) <- c("mdl_randomForest", class(mdl_fit))
  return(mdl_fit)
}#MDL_RANDOMFOREST
# Predict method for mdl_randomForest fits
predict.mdl_randomForest <- function(object, newdata = NULL, ...){
  # Drop the custom class so that dispatch falls through to the
  #     randomForest prediction method
  class(object) <- class(object)[-1]
  # Predict using randomForest prediction method
  as.numeric(stats::predict(object, newdata, ...))
}#PREDICT.MDL_RANDOMFOREST
You can then pass mdl_randomForest to the learners argument as you would with any other learner.
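As a minimal sketch of what that could look like with ddml_ate: the example below assumes that mdl_randomForest and predict.mdl_randomForest from the snippet above are already defined in your session, and that the ddml and randomForest packages are installed. The simulated data and the values ntree = 200 and sample_folds = 5 are illustrative assumptions, not recommendations.

```r
# Assumes mdl_randomForest() and predict.mdl_randomForest() from the snippet
# above are already defined in the current session.
library(ddml)

# Simulated data (illustrative only): binary treatment D, five controls X
set.seed(123)
n <- 500
X <- matrix(rnorm(n * 5), nrow = n,
            dimnames = list(NULL, paste0("x", 1:5)))
D <- as.numeric(X[, 1] + rnorm(n) > 0)
y <- D + X[, 2] + rnorm(n)

# Single learner: a list with the learner function and optional arguments
ate_fit <- ddml_ate(y, D, X,
                    learners = list(what = mdl_randomForest,
                                    args = list(ntree = 200)),
                    sample_folds = 5,
                    silent = TRUE)

# Estimated average treatment effect
ate_fit$ate
```

As far as I recall from the ddml documentation, stacking with several learners instead takes a list of lists, each with a fun element (rather than what) and optional args, for example list(list(fun = ols), list(fun = mdl_randomForest, args = list(ntree = 200))). Please check ?ddml_ate for the exact form in your installed version.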
If I want to use a randomForest learner, how can I add it to the learners list in ddml_ate? Thank you!