ecpolley / SuperLearner

Current version of the SuperLearner R package

Problem when performing prediction using predict.SuperLearner #133

Closed: william-denault closed this issue 3 years ago

william-denault commented 3 years ago

Dear Eric Polley,

I am trying to use your R implementation of SuperLearner to make some predictions. However, I get an error I cannot make sense of when using the predict.SuperLearner function. It seems not to accept the data I pass to it, even though these are exactly the same data I used for training. When I run the following script, I get this error: `Error in array(x, c(length(x), 1L), if (!is.null(names(x))) list(names(x), : 'data' must be of a vector type, was 'NULL'`

```r
rm(list = ls())
library(SuperLearner)
library(polspline)
library(xgboost)
library(glmnet)
library(ranger)

# BART machine

# binary outcome
set.seed(1)
N <- 200
X <- matrix(rnorm(N*10), N, 10)
X <- as.data.frame(X)
colnames(X) <- LETTERS[1:10]
Y <- rbinom(N, 1, plogis(.2*X[, 1] + .1*X[, 2] - .2*X[, 3] + .1*X[, 3]*X[, 4] - .2*abs(X[, 4])))
x_tr <- X[1:150, ]
y_tr <- Y[1:150]

tune = list(ntrees = c(10, 20, 50, 100),
            max_depth = 2:5,
            shrinkage = c(0.001, 0.01, 0.1))

# Set detailed names = T so we can see the configuration for each function.
# Also shorten the name prefix.
learners = create.Learner("SL.xgboost", tune = tune, detailed_names = TRUE, name_prefix = "xgb")

# 48 configurations
length(learners$names)

# 10 fold cross_validation
SL = SuperLearner(Y = y_tr, X = x_tr, family = binomial(), #weights=my_w,
                  # For a real analysis we would use V = 10.
                  method = "method.AUC",
                  SL.library = c("SL.glmnet",
                                 learners$names,
                                 "SL.ranger",
                                 "SL.knn",
                                 "SL.nnet"))

str(x_tr)

predict(SL, newdata = x_tr)
```
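The "48 configurations" comment in the script follows from crossing every tuning value: 4 values of `ntrees` times 4 of `max_depth` times 3 of `shrinkage`. A quick base-R check, assuming that `create.Learner` expands the `tune` list into all combinations in the same way `expand.grid` does:

```r
# Same tuning list as in the script above.
tune <- list(ntrees = c(10, 20, 50, 100),
             max_depth = 2:5,
             shrinkage = c(0.001, 0.01, 0.1))

# Assumption: one learner configuration per row of the full cross of values.
grid <- expand.grid(tune, stringsAsFactors = FALSE)
n_configs <- nrow(grid)
print(n_configs)
```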

However, the predict.SuperLearner function works fine on this example:

```r
#install.packages(c("caret", "glmnet", "randomForest", "ggplot2", "RhpcBLASctl"))

# Load a dataset from the MASS package.
data(Boston, package = "MASS")

# Review info on the Boston dataset.
?MASS::Boston

# Check for any missing data - looks like we don't have any.
colSums(is.na(Boston))

outcome = Boston$medv

# Create a dataframe to contain our explanatory variables.
data = subset(Boston, select = -medv)

# Check structure of our dataframe.
str(data)

dim(data)

set.seed(1)

# Reduce to a dataset of 150 observations to speed up model fitting.
train_obs = sample(nrow(data), 150)

# X is our training sample.
x_train = data[train_obs, ]

# Create a holdout set for evaluating model performance.
# Note: cross-validation is even better than a single holdout sample.
x_holdout = data[-train_obs, ]

# Create a binary outcome variable: towns in which median home value is > 22,000.
outcome_bin = as.numeric(outcome > 22)

y_train = outcome_bin[train_obs]
y_holdout = outcome_bin[-train_obs]

# Review the outcome variable distribution.
table(y_train, useNA = "ifany")

library(SuperLearner)
listWrappers()
SL.bartMachine
set.seed(1)

# Fit lasso model.
set.seed(1)
sl = SuperLearner(Y = y_train, X = x_train, family = binomial(),
                  SL.library = c("SL.mean", "SL.glmnet", "SL.ranger", "SL.nnet"))
sl
pred = predict(sl, x_holdout, onlySL = TRUE)

# Check the structure of this prediction object.
str(pred)
str(x_holdout)
```

Your package seems very neatly made, and I would like to use it for my research. Do you have any idea why the prediction function does not work in the first example?

Best regards, William Denault

ecpolley commented 3 years ago

Hi William, some of the SuperLearner predict methods require the original training data when predicting on new observations, and SL.knn is one of them. In your example, also pass the X and Y values when you call predict:

```r
predict(SL, newdata = x_tr, X = x_tr, Y = y_tr)
```
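Why a kNN learner needs `X` and `Y` at prediction time: k-nearest neighbours learns no parameters, so a fitted kNN object has nothing to predict from except the training data themselves. The base-R sketch below is a hypothetical stand-in, not SuperLearner's implementation: `toy_knn_fit()` and `predict.toy_knn()` are illustrative names, and the fit object is assumed to keep only `k`, which is what the reply implies for SL.knn.

```r
# Hypothetical toy learner in base R; NOT SuperLearner's SL.knn code.
toy_knn_fit <- function(Y, X, k = 10) {
  # kNN learns no parameters: the "fit" keeps only k, not the training data.
  structure(list(k = k), class = "toy_knn")
}

# S3 predict method: the training data must be passed back in as X and Y.
predict.toy_knn <- function(object, newdata, X, Y, ...) {
  X <- as.matrix(X)
  newdata <- as.matrix(newdata)
  # For each new row, average Y over its k nearest training rows (Euclidean).
  apply(newdata, 1, function(row) {
    d <- sqrt(colSums((t(X) - row)^2))
    mean(Y[order(d)[seq_len(object$k)]])
  })
}

set.seed(1)
X <- matrix(rnorm(200), nrow = 100, ncol = 2)
Y <- as.numeric(X[, 1] > 0)
fit <- toy_knn_fit(Y, X, k = 5)

# Works: the training data are supplied alongside newdata.
p <- predict(fit, newdata = X[1:3, ], X = X, Y = Y)

# Fails: without X and Y there are no neighbours to search.
err <- tryCatch(predict(fit, newdata = X[1:3, ]), error = conditionMessage)
```

The failing call stops with an R "argument is missing" error; in SuperLearner the missing training data surface differently, as the `'data' must be of a vector type, was 'NULL'` message in the report.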

The error message could be better in these cases; I think we could add a check to predict.SL.knn with a clearer message to avoid this in the future. Thanks for the note.
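One possible shape for such a check, sketched in base R. `check_training_data()` is a hypothetical helper, not part of SuperLearner: a predict method that needs the training data could call it first, so that a missing `X` or `Y` produces an actionable message. The column and row comparisons are additional assumptions about what is worth validating.

```r
# Hypothetical helper, NOT part of SuperLearner: validate that the original
# training data were supplied before a learner that needs them predicts.
check_training_data <- function(newdata, X = NULL, Y = NULL) {
  if (is.null(X) || is.null(Y)) {
    stop("this learner needs the original training data: call ",
         "predict(fit, newdata, X = X_train, Y = Y_train)", call. = FALSE)
  }
  # Assumed extra checks: matching columns and matching training sizes.
  if (ncol(as.matrix(newdata)) != ncol(as.matrix(X))) {
    stop("newdata has ", ncol(as.matrix(newdata)), " columns but X has ",
         ncol(as.matrix(X)), call. = FALSE)
  }
  if (nrow(as.matrix(X)) != length(Y)) {
    stop("X has ", nrow(as.matrix(X)), " rows but Y has length ", length(Y),
         call. = FALSE)
  }
  invisible(TRUE)
}

# Usage: the first call fails with a readable message; the second passes.
x_tr <- data.frame(A = rnorm(5), B = rnorm(5))
y_tr <- c(0, 1, 0, 1, 1)
msg <- tryCatch(check_training_data(x_tr), error = conditionMessage)
ok <- check_training_data(x_tr, X = x_tr, Y = y_tr)
```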

william-denault commented 3 years ago

Thank you for your quick feedback. It was indeed caused by SL.knn.

Great package :)