ropensci / aorsf

Accelerated Oblique Random Survival Forests
https://docs.ropensci.org/aorsf

Pass additional arguments to customized functions in identifying linear combinations of predictors (suggestion + issue) #67

Closed: AbubakerSuliman closed this issue 3 days ago

AbubakerSuliman commented 2 weeks ago

Dear Prof. @bcjaeger, thank you so much for such a great package.

First, here are my two cents on improving the speed of method = 'net'.

In penalized_cph.R, I'd suggest looping only over the indices where fit$df takes a new value instead of over the complete list. Fitting a custom penalized Cox regression that loops this way on pbc_orsf reduced run time by 10% to 30%.

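# fit$df is the number of nonzero coefficients at each lambda on the glmnet path.
# Keep index 1 plus every index where df increases, skipping repeated df values.
indxs = c(1, which(diff(fit$df) >= 1) + 1)
for(i in indxs){
  # return the first fit that reaches target_df, or the last candidate index
  if(fit$df[i] >= target_df || i == tail(indxs, 1)){
    return(matrix(fit$beta[, i, drop=TRUE], ncol = 1))
  }
}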
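A minimal base-R sketch for comparing selection rules, assuming `fit` looks like a glmnet fit (`fit$df` gives the number of nonzero coefficients at each lambda, and `fit$beta` has one column per lambda). `select_beta()` swaps the loop for a vectorized `which()` and falls back to the last column; that fallback is an assumption about how the full-list loop behaves. `select_beta_unique()` wraps the loop above. Both names are illustrative, not aorsf functions.

```r
# Vectorized rule: first lambda index whose df reaches target_df,
# falling back to the last index when the target is never reached.
select_beta <- function(fit, target_df) {
  i <- which(fit$df >= target_df)[1]
  if (is.na(i)) i <- length(fit$df)
  matrix(fit$beta[, i, drop = TRUE], ncol = 1)
}

# The loop proposed in this issue, wrapped for comparison: it only visits
# index 1 and the indices where fit$df increases.
select_beta_unique <- function(fit, target_df) {
  indxs <- c(1, which(diff(fit$df) >= 1) + 1)
  for (i in indxs) {
    if (fit$df[i] >= target_df || i == tail(indxs, 1)) {
      return(matrix(fit$beta[, i, drop = TRUE], ncol = 1))
    }
  }
}

# Mock glmnet-like fit: 3 predictors, 6 lambda values.
mock_fit <- list(
  df   = c(0, 1, 1, 2, 3, 3),
  beta = matrix(seq_len(18), nrow = 3)
)

identical(select_beta(mock_fit, 2), select_beta_unique(mock_fit, 2))  # TRUE
select_beta(mock_fit, 10)[, 1]         # 16 17 18 (last column)
select_beta_unique(mock_fit, 10)[, 1]  # 13 14 15 (first column of last df run)
```

When the target is reached, both rules pick the same column. When it is never reached they can differ, because the unique-values loop returns the first column of the last run of equal df values rather than the last column, so that edge case may be worth checking before swapping the loop in.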

Second, I'm exploring different methods for creating linear combinations of predictors; however, when I use a custom function, I can't pass it an additional argument (e.g., target_df) or access one from the parent environment. The following function throws an error. I would appreciate any ideas on how to solve this issue.

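f_net <- function(x_node, y_node, w_node, add_arg) {
  # forcing add_arg errors when the caller supplies only
  # x_node, y_node, and w_node
  add_arg
  colnames(y_node) <- c('time', 'status')
  colnames(x_node) <- paste("x", seq(ncol(x_node)), sep = '')

  data <- as.data.frame(cbind(y_node, x_node))

  # too few observations in this node: return random coefficients
  if(nrow(data) <= 10)
    return(matrix(runif(ncol(x_node)), ncol = 1))

  # elastic net (alpha = 0.5) Cox model fit to the node data
  suppressWarnings(
    fit <- try(
      glmnet::glmnet(x = x_node,
                     y = survival::Surv(data$time, data$status),
                     weights = w_node,
                     alpha = 0.5,
                     family = "cox"),
      silent = TRUE
    )
  )

  # fall back to random coefficients if glmnet failed
  if(inherits(fit, "try-error")){
    return(matrix(runif(ncol(x_node)), nrow=ncol(x_node), ncol=1))
  }

  # first fit with at least 5 nonzero coefficients (target_df hard-coded
  # as 5), checking only the indices where fit$df increases
  indxs = c(1, which(diff(fit$df)>=1)+1)
  for(i in indxs){
    if(fit$df[i] >= 5 || i == tail(indxs, 1)){
      return(matrix(fit$beta[, i, drop=TRUE], ncol = 1))
    }
  }

}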
bcjaeger commented 2 weeks ago

Thank you! If I could get the time, I'd want to write routines in C++ that mimic glmnet so that we wouldn't have to call the R function.

For sending customized input values to the R function in aorsf, why not just define multiple functions that each use a different hard-coded value of target_df? It's logistically complicated to allow general objects to be passed into these functions, because we'd have to know what to declare those objects as in C++.
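One way to follow this suggestion without copying the function by hand is a small factory that writes target_df (and alpha) into the function body as literal values with `substitute()`, so each generated function is a hard-coded variant with the three-argument x_node, y_node, w_node signature used in this thread. `make_net_fun()` is a hypothetical helper, not part of aorsf. It assumes y_node holds time in column 1 and status in column 2, and that aorsf accepts any R function with these three arguments; it has not been run against aorsf here.

```r
# Hypothetical factory: returns function(x_node, y_node, w_node) whose body
# contains target_df and alpha as literals, so no extra argument is needed.
make_net_fun <- function(target_df, alpha = 0.5) {
  body_expr <- substitute({
    fit <- try(
      suppressWarnings(
        glmnet::glmnet(x = x_node,
                       y = survival::Surv(y_node[, 1], y_node[, 2]),
                       weights = w_node,
                       alpha = ALPHA,
                       family = "cox")
      ),
      silent = TRUE
    )
    # random coefficients if glmnet fails, as in f_net above
    if (inherits(fit, "try-error")) {
      return(matrix(runif(ncol(x_node)), ncol = 1))
    }
    # first lambda whose df reaches the target, else the last lambda
    i <- which(fit$df >= TARGET_DF)[1]
    if (is.na(i)) i <- length(fit$df)
    matrix(fit$beta[, i, drop = TRUE], ncol = 1)
  }, list(TARGET_DF = target_df, ALPHA = alpha))

  f <- function(x_node, y_node, w_node) NULL
  body(f) <- body_expr
  f
}

f_net_df5  <- make_net_fun(target_df = 5)
f_net_df10 <- make_net_fun(target_df = 10)
```

Printing `f_net_df5` shows `which(fit$df >= 5)` in its body. The generated functions could then be passed the same way as f_rando in the MWE later in this thread, e.g. `orsf_control_survival(method = f_net_df5)`.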

AbubakerSuliman commented 1 week ago

Thanks for the prompt response. This is well noted and received.

I have a new issue with using a custom function in mlr3: control_type only accepts one of p_fct(levels = c("fast", "cph", "net"), default = "fast", tags = "train") (R/learner_aorsf_surv_aorsf.R). Looking at the code, supporting a custom function seems possible, but I suspect you have a technical or design reason for not including it. I'd appreciate any help or ideas with this issue.

Update 1: Here is a first attempt at modifying the aorsf learner in mlr3extralearners, after a quick read of the code:

ps = ps(
  ...
  control_type = p_fct(levels = c("fast", "cph", "net", "custom"), default = "fast", tags = "train"),
  control_custom_fun = p_uty(custom_check = function(x) checkmate::checkFunction(x, nargs = 3),
                             depends = control_type == "custom", tags = "train"),
  ...
)
...,
"custom" = {
  aorsf::orsf_control_survival(
    method = pv$control_custom_fun
  )
}
)
# these parameters are used to organize the control arguments
# above but are not used directly by aorsf::orsf(), so:
pv = remove_named(pv, c("control_type",
                       ...,
                       "control_custom_fun"))
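`checkmate::checkFunction(x, nargs = 3)` accepts any three-argument function, such as `function(a, b, c)`. A slightly stricter check could also require the argument names used throughout this thread; whether aorsf itself insists on these exact names is an assumption here. checkmate's `checkFunction()` also has an `args` argument for this, or the check can be written directly, as in this hypothetical base-R sketch. Like the checkmate `check*` functions, it returns TRUE on success and an error string otherwise, which is the form `custom_check` expects.

```r
# Hypothetical stricter custom_check for control_custom_fun:
# TRUE if x is a function with arguments x_node, y_node, w_node (in order),
# otherwise a character string describing the problem.
check_lincomb_fun <- function(x) {
  expected <- c("x_node", "y_node", "w_node")
  if (!is.function(x)) {
    return("Must be a function")
  }
  if (!identical(names(formals(x)), expected)) {
    return(paste0("Must have arguments ", paste(expected, collapse = ", ")))
  }
  TRUE
}

check_lincomb_fun(function(x_node, y_node, w_node) NULL)  # TRUE
check_lincomb_fun(function(a, b, c) NULL)                  # error string
```

It would replace the anonymous function in `p_uty(custom_check = check_lincomb_fun, ...)`.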

It seems to work fine, but I have an issue with importance(): every variable gets an importance of 0.

> learner$importance()
x8 x7 x6 x5 x4 x3 x2 x1
 0  0  0  0  0  0  0  0 
bcjaeger commented 1 week ago

Here is a first attempt at modifying the aorsf learner in mlr3extralearners, after a quick read of the code

This is awesome! Thank you for writing this code. I wonder if we could request the custom method be added to the aorsf learner in mlr3extralearners once we've figured out the importance issue. Do you also get importance values of 0 for this aorsf model when you fit it using aorsf::orsf?

AbubakerSuliman commented 1 week ago

I wonder if we could request the custom method be added to the aorsf learner in mlr3extralearners once we've figured out the importance issue

Why not? It would allow benchmarking of any linear-combination method.

Do you also get importance values of 0 for this aorsf model when you fit it using aorsf::orsf?

Yes. Interestingly, it works fine with importance = 'negate' or 'permute'.

bcjaeger commented 6 days ago

Thank you! Could you clarify the second item for me? Did aorsf::orsf() give you all 0's for importance values, or did it work fine (i.e., giving non-zero importance values)?

For the PR, would you like to take the lead by initiating an issue on mlr3extralearners? If you'd like, I could do this, but I am slowed down by other obligations and I also want to make sure you get credited for the awesome work you've done.

AbubakerSuliman commented 6 days ago

Thank you! Could you clarify the second item for me? Did aorsf::orsf() give you all 0's for importance values, or did it work fine (i.e., giving non-zero importance values)?

aorsf::orsf() with a custom function works fine when I calculate importance using "negate" or "permute"; however, every importance value is 0 when importance = "anova". Here is an MWE:

library(aorsf)
# custom linear-combination function that ignores the data and
# returns one random coefficient per predictor
f_rando <- function(x_node, y_node, w_node){
  matrix(runif(ncol(x_node)), ncol=1)
}
fit_rando_anova <- orsf(pbc_orsf,
                        Surv(time, status) ~ . - id,
                        control = orsf_control_survival(method = f_rando),
                        importance = "anova",
                        tree_seeds = 329)
fit_rando_negate <- orsf(pbc_orsf,
                        Surv(time, status) ~ . - id,
                        control = orsf_control_survival(method = f_rando),
                        importance = "negate",
                        tree_seeds = 329)
fit_rando_permute <- orsf(pbc_orsf,
                         Surv(time, status) ~ . - id,
                         control = orsf_control_survival(method = f_rando),
                         importance = "permute",
                         tree_seeds = 329)

fit_rando_anova$importance
   stage  protime platelet     trig      ast alk.phos   copper  albumin     chol     bili 
       0        0        0        0        0        0        0        0        0        0 
   edema  spiders   hepato  ascites      sex      age      trt 
       0        0        0        0        0        0        0 

fit_rando_negate$importance
        bili       copper      protime        stage          age          ast          sex 
 0.061821928  0.061275078  0.045544739  0.039651143  0.039253153  0.027136941  0.019333503 
      hepato         chol      spiders         trig     alk.phos      ascites      albumin 
 0.015253945  0.014237792  0.010395895  0.010026580  0.009735672  0.009062175  0.007236837 
       edema          trt     platelet 
 0.005684531  0.003000386 -0.002974909 

fit_rando_permute$importance
       copper          bili       protime           age         stage           ast          chol 
 0.0288431287  0.0275947301  0.0235931010  0.0216206097  0.0174380106  0.0136475128  0.0076317414 
       hepato       spiders       albumin          trig       ascites         edema      alk.phos 
 0.0060957151  0.0060563554  0.0052816286  0.0050055876  0.0049497113  0.0028515817  0.0019844534 
          sex           trt      platelet 
-0.0005690087 -0.0020921897 -0.0030092387 

Regarding the PR, many thanks for the kind words. Sure, I will start the PR soon.

bcjaeger commented 5 days ago

Ahh, I see, that makes sense. ANOVA importance requires p-values, so I didn't even attempt to compute ANOVA importance when a custom function is used to get linear combinations of predictors. I think aorsf::orsf() should probably throw an error if someone uses a custom function with ANOVA importance, to prevent this confusing result. Do you think that would be helpful?
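A minimal sketch of the proposed guard, assuming the check runs before any trees are grown. `check_anova_with_custom()` is a hypothetical helper, not an aorsf function: it stops when `method` is a user-supplied R function and `importance = "anova"`, and otherwise returns TRUE invisibly.

```r
# Hypothetical guard: a custom linear-combination function returns
# coefficients only, so there are no p-values for ANOVA importance.
check_anova_with_custom <- function(method, importance) {
  if (is.function(method) && identical(importance, "anova")) {
    stop("importance = 'anova' requires p-values, which a custom ",
         "linear-combination function does not provide; ",
         "use importance = 'negate' or 'permute' instead.",
         call. = FALSE)
  }
  invisible(TRUE)
}

f_rando <- function(x_node, y_node, w_node) matrix(runif(ncol(x_node)), ncol = 1)

check_anova_with_custom(f_rando, "negate")   # passes silently
check_anova_with_custom("cph", "anova")      # passes silently
try(check_anova_with_custom(f_rando, "anova"))  # error with the message above
```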

AbubakerSuliman commented 3 days ago

do you think that would be helpful?

Of course. I had wondered why you don't allow ANOVA for custom methods, and then I read the following on the aorsf GitHub main page: "ANOVA is very efficient computationally, but may not be as effective as permutation or negation in terms of selecting signal over noise variables."
So yes, an error message would be enough.

I opened a PR for the custom method here.

Finally, feel free to close this issue.

bcjaeger commented 3 days ago

Thank you! I should update that main page to also mention that we can only compute ANOVA importance if the linear combination method allows us to compute p-values for the variables being combined.
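To illustrate why a custom function yields all zeros, here is a simplified sketch of ANOVA-style importance: a variable's importance is the share of the linear combinations it appears in where its coefficient p-value falls below a cutoff. The node-by-variable matrices `used` and `pvalues`, the function name `anova_importance()`, and the 0.01 cutoff are assumptions for illustration, not aorsf's internal representation or C++ implementation.

```r
# Sketch: importance = (# uses with p-value < cutoff) / (# uses).
anova_importance <- function(used, pvalues, cutoff = 0.01) {
  stopifnot(identical(dim(used), dim(pvalues)))
  # a missing p-value (NA) never counts as significant
  significant <- used & !is.na(pvalues) & pvalues < cutoff
  n_used <- colSums(used)
  # variables never used in any node get 0 rather than NaN
  ifelse(n_used > 0, colSums(significant) / n_used, 0)
}

# rows = nodes, columns = predictors
used <- matrix(c(TRUE, TRUE, FALSE,
                 TRUE, FALSE, TRUE),
               nrow = 3, dimnames = list(NULL, c("bili", "age")))
# p-values from a Cox-based method; NA where a variable was not used
p_cox <- matrix(c(0.001, 0.2, NA,
                  0.005, NA, 0.004),
                nrow = 3, dimnames = dimnames(used))
# a custom function supplies coefficients but no p-values
p_none <- matrix(NA_real_, nrow = 3, ncol = 2, dimnames = dimnames(used))

anova_importance(used, p_cox)   # bili 0.5, age 1
anova_importance(used, p_none)  # bili 0, age 0
```

With no p-values available, no use of any variable can count as significant, which matches the all-zero output of fit_rando_anova above.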