ecpolley / SuperLearner

Current version of the SuperLearner R package

bartmachine updates #145

Open · ecpolley opened this issue 2 years ago

ecpolley commented 2 years ago

With the current version of bartMachine (version 1.3), should we update SL.bartMachine to call `bartMachine::build_bart_machine()` directly? I see an environment error when it tries to locate the correct variables, and I will explore further. Example below:

    library(SuperLearner)

    set.seed(23432)

    # training set
    n <- 500
    p <- 50
    X1 <- matrix(rnorm(n * p), nrow = n, ncol = p)
    colnames(X1) <- paste("X", 1:p, sep = "")
    X1 <- data.frame(X1)
    Y1 <- X1[, 1] + sqrt(abs(X1[, 2] * X1[, 3])) + X1[, 2] - X1[, 3] + rnorm(n)

    # generate Library and run Super Learner
    SL.library <- c("SL.glm", "SL.bartMachine", "SL.mean")
    test <- SuperLearner(Y = Y1, X = X1, SL.library = SL.library,
                         verbose = TRUE, method = "method.NNLS")
    test

    # candidate wrapper that calls bartMachine::build_bart_machine() directly
    SL.mybartMachine <- function(Y, X, newX, family, obsWeights, id,
                                 num_trees = 50, num_burn_in = 250, verbose = FALSE,
                                 alpha = 0.95, beta = 2, k = 2, q = 0.9, nu = 3,
                                 num_iterations_after_burn_in = 1000, ...) {
      require("bartMachine")
      model <- bartMachine::build_bart_machine(X = X, y = Y,
        num_trees = num_trees, num_burn_in = num_burn_in, verbose = verbose,
        alpha = alpha, beta = beta, k = k, q = q, nu = nu,
        num_iterations_after_burn_in = num_iterations_after_burn_in,
        serialize = TRUE)
      # predictions for the rows SuperLearner passes in newX
      pred <- predict(model, newX)
      # reuse the "SL.bartMachine" class so SuperLearner's existing
      # predict.SL.bartMachine() method is dispatched on this fit
      fit <- list(object = model)
      class(fit) <- c("SL.bartMachine")
      out <- list(pred = pred, fit = fit)
      return(out)
    }

    SL.library <- c("SL.glm", "SL.mybartMachine", "SL.mean")
    test <- SuperLearner(Y = Y1, X = X1, SL.library = SL.library,
                         verbose = TRUE, method = "method.NNLS")
    test
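The wrapper above follows SuperLearner's convention of taking `(Y, X, newX, family, obsWeights, id, ...)` and returning `list(pred = ..., fit = ...)`, where the class set on `fit` decides which `predict()` method S3 dispatch picks later. As a minimal base-R sketch of that contract, assuming nothing beyond `stats`, here is a hypothetical `SL.mylm` wrapper with a matching hypothetical `predict.SL.mylm` method, with `lm()` standing in for bartMachine. It does not need SuperLearner or Java, so it can be called directly to check the fit/predict path in isolation.

```r
# Hypothetical wrapper (lm() stands in for bartMachine) following the same
# contract as SL.mybartMachine: return list(pred = ..., fit = <classed object>).
SL.mylm <- function(Y, X, newX, family = gaussian(),
                    obsWeights = rep(1, length(Y)), id = NULL, ...) {
  # family and id are accepted for interface compatibility but unused here
  dat <- data.frame(Y = Y, X)
  model <- lm(Y ~ ., data = dat, weights = obsWeights)
  pred <- predict(model, newdata = newX)
  fit <- list(object = model)
  # the class attribute controls which predict() method is dispatched later
  class(fit) <- "SL.mylm"
  list(pred = pred, fit = fit)
}

# Hypothetical matching S3 method: predict() on the stored fit lands here.
predict.SL.mylm <- function(object, newdata, ...) {
  predict(object$object, newdata = newdata)
}

set.seed(1)
n <- 100
X <- data.frame(X1 = rnorm(n), X2 = rnorm(n))
Y <- X$X1 - 2 * X$X2 + rnorm(n)
out <- SL.mylm(Y = Y, X = X, newX = X)
length(out$pred)  # 100
```

Because `class(fit)` is `"SL.mylm"`, `predict(out$fit, newdata = X)` dispatches to `predict.SL.mylm()`, which is the same mechanism that makes the `"SL.bartMachine"` class in the wrapper above route to the package's existing predict method.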