ecpolley / SuperLearner

Current version of the SuperLearner R package

SL.dbarts modification #102

Closed benkeser closed 6 years ago

benkeser commented 6 years ago

Hi all, I was fooling around with SL.dbarts and noticed that it sometimes returns values outside (0, 1) when Y is binary. It seems bart wants a factor input for the outcome variable if it's to properly handle binary outcomes. Here's a simple modification that does the trick.

SL.dbarts = function(Y, X, newX, family, obsWeights, id,
                     sigest = NA,
                     sigdf = 3,
                     sigquant = 0.90,
                     k = 2.0,
                     power = 2.0,
                     base = 0.95,
                     binaryOffset = 0.0,
                     ntree = 200,
                     ndpost = 1000,
                     nskip = 100,
                     printevery = 100,
                     keepevery = 1,
                     keeptrainfits = T,
                     usequants = F,
                     numcut = 100,
                     printcutoffs = 0,
                     nthread = 1,
                     keepcall = T,
                     verbose = F,
                     ...) {

  .SL.require("dbarts")
  if (family$family == "binomial") {
    y.train <- factor(Y)
  } else {
    y.train <- Y
  }
  model =
    dbarts::bart(x.train = X,
                 y.train = y.train,
                 # dbarts::bart() has no predict() method, so newX must be
                 # passed in at training time.
                 x.test = newX,
                 sigest = sigest,
                 sigdf = sigdf,
                 sigquant = sigquant,
                 k = k,
                 power = power,
                 base = base,
                 binaryOffset = binaryOffset,
                 weights = obsWeights,
                 ntree = ntree,
                 ndpost = ndpost,
                 nskip = nskip,
                 printevery = printevery,
                 keepevery = keepevery,
                 keeptrainfits = keeptrainfits,
                 usequants = usequants,
                 numcut = numcut,
                 printcutoffs = printcutoffs,
                 nthread = nthread,
                 keepcall = keepcall,
                 verbose = verbose)

  # TODO: there is no predict!
  #pred = predict(model, newdata = newX)
  if (family$family == "gaussian") {
    pred = model$yhat.test.mean
  } else {
    # No mean is provided for binary Y :/
    pred = colMeans(model$yhat.test)
  }

  fit = list(object = model)
  class(fit) = c("SL.dbarts")
  out = list(pred = pred, fit = fit)
  return(out)
}
benkeser commented 6 years ago

Just kidding: now I'm seeing that bart claims it fits a probit link if all Y are in {0, 1}. But it still returns values outside (0, 1). Seems like a bart problem, not a SuperLearner problem. Closing the issue...

ck37 commented 6 years ago

Well, bartMachine definitely has this problem. I have a fix in my ck37r package (BartMachine2) but haven't had time to do a PR yet.

Just curious, are you using dbarts because bartMachine is a pain to get working, or for other reasons?



benkeser commented 6 years ago

Alright, sorry for being obnoxious and spamming here, but I just noticed in the bart documentation that if the outcome is binary, the output of bart needs to be passed through pnorm to get back to the probability scale. Thus, this modification should work:

SL.dbarts = function(Y, X, newX, family, obsWeights, id,
                     sigest = NA,
                     sigdf = 3,
                     sigquant = 0.90,
                     k = 2.0,
                     power = 2.0,
                     base = 0.95,
                     binaryOffset = 0.0,
                     ntree = 200,
                     ndpost = 1000,
                     nskip = 100,
                     printevery = 100,
                     keepevery = 1,
                     keeptrainfits = T,
                     usequants = F,
                     numcut = 100,
                     printcutoffs = 0,
                     nthread = 1,
                     keepcall = T,
                     verbose = F,
                     ...) {

  .SL.require("dbarts")

  model =
    dbarts::bart(x.train = X,
                 y.train = Y,
                 # dbarts::bart() has no predict() method, so newX must be
                 # passed in at training time.
                 x.test = newX,
                 sigest = sigest,
                 sigdf = sigdf,
                 sigquant = sigquant,
                 k = k,
                 power = power,
                 base = base,
                 binaryOffset = binaryOffset,
                 weights = obsWeights,
                 ntree = ntree,
                 ndpost = ndpost,
                 nskip = nskip,
                 printevery = printevery,
                 keepevery = keepevery,
                 keeptrainfits = keeptrainfits,
                 usequants = usequants,
                 numcut = numcut,
                 printcutoffs = printcutoffs,
                 nthread = nthread,
                 keepcall = keepcall,
                 verbose = verbose)

  # TODO: there is no predict!
  #pred = predict(model, newdata = newX)
  if (family$family == "gaussian") {
    pred = model$yhat.test.mean
  } else {
    # No mean is provided for binary Y :/
    pred = colMeans(pnorm(model$yhat.test))
  }

  fit = list(object = model)
  class(fit) = c("SL.dbarts")
  out = list(pred = pred, fit = fit)
  return(out)
}
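To sanity-check the fix, a quick simulation along these lines should do it (a sketch: the data are made up, and it assumes dbarts is installed and the wrapper above is sourced):

```r
# Simulated binary outcome; after the pnorm fix, predictions should be
# valid probabilities.
set.seed(1)
n <- 200
X <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
Y <- rbinom(n, 1, plogis(X$x1 - X$x2))

fit <- SL.dbarts(Y = Y, X = X, newX = X, family = binomial(),
                 obsWeights = rep(1, n), id = seq_len(n))
stopifnot(all(fit$pred >= 0 & fit$pred <= 1))
```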
benkeser commented 6 years ago

I'd tried bartMachine via the caret package before, but found it to be really slow. Too slow to want to include in a super learner, at least. I admit that I spent zero time looking at the documentation to figure out whether the speed was my fault or inherent to the method. Are there performance benefits to bartMachine over dbarts?

ck37 commented 6 years ago

Well, my main issue with dbarts is that it can't predict on new data: any test data has to be supplied during training (https://github.com/vdorie/dbarts/issues/7).

For bartMachine I've found that as long as multiple cores are used (bartMachine::set_bart_machine_num_cores(x)) it isn't too slow, but I should try dbarts again and compare. Avoiding rJava is a big advantage of dbarts.
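For reference, the call mentioned above looks like this (assumes bartMachine and a working rJava setup; the core count is an arbitrary example):

```r
library(bartMachine)

# Let bartMachine parallelize its Gibbs sampling across cores;
# call this once before fitting any models.
set_bart_machine_num_cores(4)
```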

benkeser commented 6 years ago

Yeah, that is a bit annoying with dbarts. Good to know about parallel bartMachine.

ecpolley commented 6 years ago

Thanks, is the only change the line for prediction in the binomial case?

   # No mean is provided for binary Y :/
   pred = colMeans(pnorm(model$yhat.test))
benkeser commented 6 years ago

Yep

ecpolley commented 6 years ago

done

ck37 commented 6 years ago

It would be good to add a test for this too.

ecpolley commented 6 years ago

Test if predicted values are outside [0, 1]?

ck37 commented 6 years ago

Yup, ideally (but not mandatory) with a test case that would cause that to happen if we used the original code.
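Something along these lines might work (a sketch in testthat style; the simulated dataset and seed are made up, and whether this particular draw would have triggered the original bug is untested):

```r
library(testthat)

test_that("SL.dbarts returns probabilities in [0, 1] for binomial family", {
  set.seed(2)
  n <- 100
  X <- data.frame(x = rnorm(n))
  Y <- rbinom(n, 1, plogis(2 * X$x))
  fit <- SL.dbarts(Y = Y, X = X, newX = X, family = binomial(),
                   obsWeights = rep(1, n), id = seq_len(n))
  expect_true(all(fit$pred >= 0 & fit$pred <= 1))
})
```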


ecpolley commented 6 years ago

I was thinking we could add a warning in the method.* when family is binomial to check the Z matrix for values outside [0, 1].
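A sketch of that check (Z here is the matrix of cross-validated predictions from the discussion; exactly where this lands inside the method.* functions is an assumption):

```r
# Warn if any learner produced cross-validated predictions outside [0, 1]
# when the outcome is binomial.
if (family$family == "binomial" && (min(Z) < 0 || max(Z) > 1)) {
  warning("Some learners returned predicted probabilities outside [0, 1]; ",
          "check the individual library algorithms.")
}
```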

ck37 commented 6 years ago

Yeah that would be great too.


benkeser commented 6 years ago

Hey guys, one more note on SL.dbarts that I just ran into. The bart function doesn't play nicely when X comes in as a data.frame with a single column. I submitted an issue to the dbarts maintainer, but for now this little snippet can be added to SL.dbarts before the call to bart to work around the error.

  if (ncol(X) == 1) {
    # bart() chokes on a one-column data.frame, so drop to a plain vector
    X <- X[, 1]
    newX <- newX[, 1]
  }