prise6 / aVirtualTwins

Adaptation of Virtual Twins method from Jared Foster
GNU General Public License v3.0

predict function in aVirtualTwins #3

Open LuxiCao opened 6 years ago

LuxiCao commented 6 years ago

Hi Prise6,

Thank you for creating this great R package, aVirtualTwins. I'm using it in a simulation study to compare subgroup identification methods. Is there a way to predict the subgroups of a test sample based on the subgroups I got from the training sample? I checked VT.predict(), but it seems to apply to the random forest result, not to the vt.tree() result.

Thank you in advance!

Regards, Luxi

prise6 commented 6 years ago

Hi LuxiCao,

Thank you for your comment. Actually, I haven't implemented this kind of function. It would be nice to have! However, I see Virtual Twins as a two-step method: estimate the twins, then explain. So there are a few ways to treat Virtual Twins as a subgroup identification method, and thus to identify subgroups and simulate the method. Anyway, if you want to use the result of the explain step, you have to estimate difft on the test set and then apply predict.rpart. Check the code below.

I noticed that in the VT.forest.one object, I should have computed the twins with something other than OOB estimators alone. I need to correct this.

To me, if a predict function like the one you describe existed, it would produce a 1/0 classification meaning "this patient benefits from the treatment". But is there a subgroup? Hard to tell. You have to consider a more general form of the result. Anyway, this kind of function would be useful and would make subgroup results more convenient to work with!
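To make the 1/0 point concrete, here is a minimal base-R sketch of how a treatment-effect estimate can be turned into a binary "benefits" label. The twin values and variable names are hypothetical and only for illustration; the assumption is that difft is the difference between the predicted outcome under treatment and under control, and that the label is 1 when difft exceeds the threshold (which, as far as I can tell, is what the ">" direction of the `sens` argument means).

```r
# Hypothetical twin predictions (not from the sepsis data):
# predicted probability of the outcome under treatment and under control
twin_treated <- c(0.80, 0.55, 0.30, 0.65)
twin_control <- c(0.50, 0.50, 0.45, 0.40)

# Estimated individual treatment effect (assumed definition of difft)
difft <- twin_treated - twin_control

# Binary label: 1 if the estimated benefit exceeds the threshold, else 0
threshold <- 0.2
benefit <- as.integer(difft > threshold)

print(data.frame(difft = difft, benefit = benefit))
```

The label says which patients appear to benefit, but it does not by itself describe a subgroup; that is what the tree in the explain step is for.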

Please tell me if I misunderstood your comment. And feel free to contribute :)

Here's some code you can use:

library(aVirtualTwins)

data("sepsis")

# reproducible
set.seed(123)

# sample row indices for the test set (1/3); the remaining 2/3 form the train set
idx = sample.int(n = nrow(sepsis), size = floor(nrow(sepsis)*1/3))

# create separate dataset
sepsis_train = sepsis[-idx,]
sepsis_test = sepsis[idx,]

# initialize vt object train
vt.obj.train = vt.data(
  dataset         = sepsis_train,
  outcome.field   = "survival",
  treatment.field = "THERAPY",
  interactions    = TRUE
)

# initialize vt object test
vt.obj.test = vt.data(
  dataset         = sepsis_test,
  outcome.field   = "survival",
  treatment.field = "THERAPY",
  interactions    = TRUE
)

# run your model with wrapper function with train dataset
vt.for.train = vt.forest(
  forest.type  = "one",
  vt.data      = vt.obj.train,
  interactions = TRUE,
  ntree        = 500
)

# create your VT.forest.one object with the model fitted above
# and feed it the test dataset
vt.for.test = VT.forest.one(
  vt.object    = vt.obj.test,
  model        = vt.for.train$model,
  interactions = TRUE
)

## # Ideally you would just call:
## vt.for.test$run()
## # and that would be it.
## # But at the time I didn't notice that run() should not
## # predict with the OOB estimator here.
## # In my own (old) simulations I used forest.fold / forest.double instead.
## # So to predict twin1 / twin2 you have to
## # rewrite the run() function outside the object:

vt.for.test$twin1 = as.vector(aVirtualTwins:::VT.predict(
  rfor    = vt.for.test$model,
  newdata = vt.obj.test$getX(interactions = vt.for.test$interactions),
  type    = vt.for.test$vt.object$type
))
if(inherits(vt.for.test, "VT.forest.one")) vt.for.test$vt.object$switchTreatment() #if one forest
vt.for.test$computeTwin2()
if(inherits(vt.for.test, "VT.forest.one")) vt.for.test$vt.object$switchTreatment() #if one forest
vt.for.test$computeDifft()
vt.for.test$vt.object$computeDelta()

# then initialize VT.tree object
vt.tree.test = vt.tree(
  tree.type = "class",
  vt.difft  = vt.for.test, 
  threshold = .2,
  maxdepth  = 2
)

# get results
(vt.subgroups(vt.tree.test))

## # Now, if you want to test the subgroups found on the train dataset,
## # use the following code:

vt.tree.train = vt.tree(
  tree.type = "class",
  vt.difft  = vt.for.train, 
  threshold = .2,
  maxdepth  = 2
)

subgroups = vt.subgroups(vt.tree.train)

# now apply the subgroup results to the test data.
# To me, it depends on which "metric" you want to
# compare the results on.
# You can directly use the character string from the
# subgroups data.frame (it comes from the getRules() function).
# My metric was relative risk:

vt.obj.test$getIncidences(subgroups$Subgroup[1])

# or if you want a real predict, use predict.rpart
library(rpart)
vt.tree.test = VT.tree.class(
  vt.difft  = vt.for.test,
  threshold = vt.tree.train$threshold,
  sens      = vt.tree.train$sens
)

pred = predict(vt.tree.train$tree, vt.tree.test$getData(), type = "class")

# cross-tabulate the tree predictions against a hard-coded subgroup rule
with(sepsis_test, table(pred, ifelse(PRAPACHE > 26 & AGE > 49.80, 1, 0)))
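The comments above mention using the rule string from the subgroups data.frame directly. As a rough sketch of what that can look like outside the package, the snippet below evaluates a rule string against a data frame with base R. The toy data and the exact rule text are hypothetical; only the column names mirror the sepsis example, and the string format is assumed to be a plain R logical expression.

```r
# Hypothetical toy test set; column names mirror the sepsis example
toy_test <- data.frame(
  PRAPACHE = c(30, 10, 28, 27, 5),
  AGE      = c(60, 70, 40, 55, 80)
)

# Hypothetical rule string, assumed to be a valid R logical expression
rule <- "PRAPACHE>=26.5 & AGE>=49.8"

# Evaluate the rule inside the data frame: one TRUE/FALSE per row
in_subgroup <- eval(parse(text = rule), envir = toy_test)

# Count rows inside and outside the subgroup
print(table(in_subgroup))
```

Because eval(parse()) runs arbitrary code, this only makes sense for rule strings you generated yourself.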

Edit

On second thought, maybe in your case you don't want to recompute difft on the test set. If so, the following code could be more useful:

library(aVirtualTwins)

data("sepsis")

# reproducible
set.seed(123)

# sample row indices for the test set (1/3); the remaining 2/3 form the train set
idx = sample.int(n = nrow(sepsis), size = floor(nrow(sepsis)*1/3))

# create separate dataset
sepsis_train = sepsis[-idx,]
sepsis_test = sepsis[idx,]

# initialize vt object train
vt.obj.train = vt.data(
  dataset         = sepsis_train,
  outcome.field   = "survival",
  treatment.field = "THERAPY",
  interactions    = TRUE
)

# initialize vt object test
vt.obj.test = vt.data(
  dataset         = sepsis_test,
  outcome.field   = "survival",
  treatment.field = "THERAPY",
  interactions    = TRUE
)

# run your model with wrapper function with train dataset
vt.for.train = vt.forest(
  forest.type  = "one",
  vt.data      = vt.obj.train,
  interactions = TRUE,
  ntree        = 500
)

vt.tree.train = vt.tree(
  tree.type = "class",
  vt.difft  = vt.for.train, 
  threshold = .2,
  maxdepth  = 2
)

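# predict subgroup membership on the test covariates with the train tree
pred = predict(vt.tree.train$tree, vt.obj.test$getX(interactions = FALSE), type = "class")
# cross-tabulate the tree predictions against a hard-coded subgroup rule
with(sepsis_test, table(pred, ifelse(PRAPACHE > 26 & AGE > 49.80, 1, 0)))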
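The last step relies only on predict.rpart, so it can be tried without aVirtualTwins. Here is a small sketch on simulated data, assuming only that the rpart package is installed; the variables x1, x2 and benefit are made up for the illustration and stand in for the covariates and the 1/0 label built from difft.

```r
library(rpart)

# Simulated training data: two covariates and a made-up 1/0 "benefit" label
set.seed(42)
n <- 200
train <- data.frame(x1 = runif(n), x2 = runif(n))
train$benefit <- factor(as.integer(train$x1 > 0.5 & train$x2 > 0.3))

# Shallow classification tree, similar in spirit to maxdepth = 2 above
fit <- rpart(
  benefit ~ x1 + x2,
  data    = train,
  method  = "class",
  control = rpart.control(maxdepth = 2)
)

# New observations need only the covariate columns used by the tree
new_patients <- data.frame(x1 = c(0.9, 0.1), x2 = c(0.8, 0.8))
pred <- predict(fit, newdata = new_patients, type = "class")
print(pred)
```

The point is that the new data do not need the outcome or difft: the covariates are enough to drop each row into a leaf of the fitted tree.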

Regards, prise6.