tripartio / ale

Interpretable Machine Learning and Statistical Inference with Accumulated Local Effects (ALE)
https://tripartio.github.io/ale/
GNU General Public License v2.0

longitudinal data #3

Closed snvv closed 1 month ago

snvv commented 3 months ago

Hello, thank you for the excellent package. I would like to know whether it is possible to use ale with longitudinal data, specifically with the LongituRF or LongCART packages. Thank you.

tripartio commented 3 months ago

Thanks for this question. I am not at all familiar with these packages, so I've been experimenting with them. One of the major goals of the ale package is to be model-agnostic, that is, to work with any R object that makes predictions from data. These two packages push this goal to its limit. To keep the replies easy to follow, I will respond separately for each package, since the answers differ.

tripartio commented 3 months ago

I was able to get LongituRF to work. Here, I will work off the examples from the documentation of LongituRF:::predict.longituRF(). (The training is somewhat slow, but I don't know anything about this package, so I'm not going to mess with their example code.)

library(LongituRF)
library(ale)  # Current CRAN version is 0.3.0
library(dplyr)

set.seed(123)
data <- DataLongGenerator(n=20) # Generate the data composed by n=20 individuals.
REEMF <- REEMforest(X=data$X,Y=data$Y,Z=data$Z,id=data$id,time=data$time,mtry=2,ntree=500,sto="BM")
#> [1] "stopped after 19 iterations."

Created on 2024-06-30 with reprex v2.1.0

The trickiest part of working with the LongituRF package is that it uses a very particular set of inputs: separate matrices and vectors. However, the ale package requires all the input data in a single data frame. So, we need to create a single data frame to hold all that data. (I like the tidyverse, so I use a tibble here, but a plain data frame would work just as well if you prefer to stick closer to base R.)

synth_data <- bind_cols(
  Y = data$Y,
  as_tibble(data$X, .name_repair = function(names) {
    paste0("X", seq_along(names))
  }),
  as_tibble(data$Z, .name_repair = function(names) {
    paste0("Z", seq_along(names))
  }),
  id = data$id,
  time = data$time,
)

synth_data
#> # A tibble: 194 × 11
#>         Y     X1       X2     X3     X4      X5     X6    Z1    Z2    id  time
#>     <dbl>  <dbl>    <dbl>  <dbl>  <dbl>   <dbl>  <dbl> <dbl> <dbl> <int> <dbl>
#>  1  0.770 -0.285 -1.19    -1.26   1.16   1.20   -2.59      1 0.920     1     1
#>  2  0.384  1.62  -0.00525 -0.381  0.878  0.397   0.496     1 1.86      1     2
#>  3  8.69   2.27   0.932    0.275  0.877 -0.0457 -1.96      1 0.780     1     3
#>  4  6.97   2.49   1.77     0.556  0.299 -0.271   4.36      1 1.49      1     4
#>  5  9.82   2.49   2.82     1.06   0.200 -0.347  -2.42      1 1.57      1     5
#>  6 16.3    3.03   2.88     1.48  -0.494 -0.575   1.12      1 0.415     1     6
#>  7 15.4    2.79   3.11     1.82  -0.190  0.162   1.86      1 0.113     1     7
#>  8 12.0    2.30   3.14     1.46  -0.329  0.675  -2.27      1 0.663     1     8
#>  9 14.8    2.81   3.14     1.72  -0.685  1.26    2.55      1 1.37      1     9
#> 10 -0.623 -0.538 -1.13    -0.772  1.07   1.29   -2.24      1 1.54      2     1
#> # ℹ 184 more rows
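For readers who prefer base R, the same single data frame can be built with data.frame(). This is a minimal sketch: `make_long_df()` is a hypothetical helper, and it assumes, as in the DataLongGenerator() output above, that `data$X` and `data$Z` are numeric matrices and that `data$Y`, `data$id` and `data$time` are vectors with one element per row.

```r
# Hypothetical helper: flatten a list with elements Y, X, Z, id and time
# (the layout returned by LongituRF::DataLongGenerator()) into one data frame.
make_long_df <- function(data) {
  X <- as.data.frame(data$X)
  names(X) <- paste0("X", seq_len(ncol(X)))
  Z <- as.data.frame(data$Z)
  names(Z) <- paste0("Z", seq_len(ncol(Z)))
  # Unnamed data frame arguments are spliced in column by column
  data.frame(Y = data$Y, X, Z, id = data$id, time = data$time)
}

# Small mock input with the same structure (made-up values, illustration only)
mock <- list(
  Y    = c(1.5, 2.0, 0.3),
  X    = matrix(1:6, nrow = 3),
  Z    = cbind(1, c(0.2, 0.4, 0.6)),
  id   = c(1L, 1L, 2L),
  time = c(1, 2, 1)
)
long_df <- make_long_df(mock)
str(long_df)
```

With the real generator output, `synth_data <- make_long_df(data)` should give the same columns as the tibble above.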

The crucial link between non-standard model types and the ale package is the pred_fun argument, which defines a custom prediction function that tells ale() how to generate predictions as a numeric vector with one value for each row of data. (A development version also supports matrix predictions, but that is not yet on CRAN.)
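Before running ale(), it can help to confirm that a custom prediction function honors this contract. Below is a minimal base R sketch for the 0.3.0 behavior described here (a plain numeric vector, one value per row). `check_pred_fun()` is a hypothetical helper, not part of the ale package, and `lm()` on the built-in mtcars data stands in for a LongituRF or LongCART model so that the example runs on its own. The stand-in prediction function uses the same `function(object, newdata, type = pred_type)` signature as the functions in this thread.

```r
# Hypothetical helper (not part of ale): verify that a custom prediction
# function returns a plain numeric vector with one value per row of newdata.
check_pred_fun <- function(pred_fun, object, newdata) {
  preds <- pred_fun(object, newdata)
  if (!is.numeric(preds) || !is.null(dim(preds))) {
    stop("pred_fun must return a plain numeric vector")
  }
  if (length(preds) != nrow(newdata)) {
    stop(sprintf("pred_fun returned %d values for %d rows",
                 length(preds), nrow(newdata)))
  }
  invisible(preds)
}

# Stand-in model: a linear model on the built-in mtcars data
fit <- lm(mpg ~ wt + hp, data = mtcars)

pred_lm <- function(object, newdata, type = pred_type) {
  # `type` is unused here, so the default pred_type is never evaluated
  as.numeric(predict(object, newdata = newdata))
}

preds <- check_pred_fun(pred_lm, fit, mtcars)
```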

Here is the custom prediction function I wrote for LongituRF:

pred_synth_data <- function(object, newdata, type = pred_type) {
  newdata_names <- list(
    X = c('X1', 'X2', 'X3', 'X4', 'X5', 'X6'),
    Z = c('Z1', 'Z2'),
    id = 'id',
    time = 'time'
  )

  LongituRF:::predict.longituRF(
    object,
    X    = newdata[newdata_names$X] |> as.matrix() |> unname(),
    Z    = newdata[newdata_names$Z] |> as.matrix() |> unname(),
    id   = newdata[[newdata_names$id]],  # vector 
    time = newdata[[newdata_names$time]]  # vector
  )
}

I initially wrote a more generic version that was not so closely hardcoded to this dataset's variable names, but I ran into trouble with parallel processing. I might be able to fix that in a future version, but with this custom prediction function, which is hardcoded for this specific dataset, parallel processing works fine. Hopefully, you can understand the code well enough to customize it for your own data. (One note of caution: the custom predict function is the slowest part of the entire ale algorithm, so try to write it as efficiently as possible. In particular, you should only use base R and avoid code from the tidyverse or other advanced packages within this custom function, or else it will run much slower.)

With these pieces in place, we can call the ale() function to generate ALE data. Note that all the arguments specified here are required for my LongituRF implementation:

ale_REEMF <- ale(
  synth_data,
  REEMF,
  x_cols = c('X1', 'X2', 'X3', 'X4', 'X5', 'X6'),
  y_col = 'Y',
  pred_fun = pred_synth_data
)

# Use the patchwork package for convenient plotting of multiple plots
ale_REEMF$plots |>
  patchwork::wrap_plots()

I won't repeat the code or example here, but LongituRF::MERF() works just as well as the LongituRF::REEMforest() that I demonstrated here.

tripartio commented 3 months ago

LongCART presents a different set of challenges.

First, version 0.3.0 of the ale package does not support multi-value predictions such as those produced by LongCART::predict.SurvCART(), so I will not try to provide code for that yet. (However, I do have a development version that should support this. So, if you need this quickly, let me know and I'll try to help you get it working.)

Second, as I said, these packages are completely new to me, so I don't understand very well how they work. As in the last example, I created a model based on the example in the LongCART::predict.LongCART() documentation (again, this is rather slow, but I'm not going to mess with their example):

library(LongCART)
library(ale)

data(ACTG175)
gvars=c("gender", "wtkg", "hemo", "homo", "drugs",
        "karnof", "oprior", "z30", "zprior", "race",
        "str2", "symptom", "treat", "offtrt")
tgvars=c(0, 1, 0, 0, 0,
         1, 0, 0, 0, 0,
         0, 0, 0, 0)
out1<- LongCART(data=ACTG175, patid="pidnum", fixed=cd4~time,
                gvars=gvars, tgvars=tgvars, alpha=0.05,
                minsplit=100, minbucket=50, coef.digits=2)

Created on 2024-06-30 with reprex v2.1.0

However, I am totally thrown off by the fact that the predictions contain more rows than the original dataset:

pred1<- predict.LongCART(object=out1, newdata=ACTG175, patid="pidnum")

head(ACTG175)
#>         pidnum age    wtkg hemo homo drugs karnof oprior z30 zprior preanti
#> 10056.0  10056  48 89.8128    0    0     0    100      0   0      1       0
#> 10059.0  10059  61 49.4424    0    0     0     90      0   1      1     895
#> 10089.0  10089  45 88.4520    0    1     1     90      0   1      1     707
#> 10093.0  10093  47 85.2768    0    1     0    100      0   1      1    1399
#> 10124.0  10124  43 66.6792    0    1     0    100      0   1      1    1352
#> 10140.0  10140  46 88.9056    0    1     1    100      0   1      1    1181
#>         race gender str2 strat symptom treat offtrt r cens days arms time cd4
#> 10056.0    0      0    0     1       0     1      0 1    0  948    2    0 422
#> 10059.0    0      0    1     3       0     1      0 0    1 1002    3    0 162
#> 10089.0    0      1    1     3       0     1      1 1    0  961    3    0 326
#> 10093.0    0      1    1     3       0     1      0 0    0 1166    3    0 287
#> 10124.0    0      1    1     3       0     0      0 1    0 1090    0    0 504
#> 10140.0    0      1    1     3       0     1      0 1    0 1181    1    0 235

nrow(ACTG175)
#> [1] 6417

head(pred1)
#>   pidnum age    wtkg hemo homo drugs karnof oprior z30 zprior preanti race
#> 1  10056  48 89.8128    0    0     0    100      0   0      1       0    0
#> 2  10056  48 89.8128    0    0     0    100      0   0      1       0    0
#> 3  10056  48 89.8128    0    0     0    100      0   0      1       0    0
#> 4  10056  48 89.8128    0    0     0    100      0   0      1       0    0
#> 5  10056  48 89.8128    0    0     0    100      0   0      1       0    0
#> 6  10056  48 89.8128    0    0     0    100      0   0      1       0    0
#>   gender str2 strat symptom treat offtrt r cens days arms time cd4 node
#> 1      0    0     1       0     1      0 1    0  948    2    0 422   21
#> 2      0    0     1       0     1      0 1    0  948    2    0 422   21
#> 3      0    0     1       0     1      0 1    0  948    2    0 422   21
#> 4      0    0     1       0     1      0 1    0  948    2   20 477   21
#> 5      0    0     1       0     1      0 1    0  948    2   20 477   21
#> 6      0    0     1       0     1      0 1    0  948    2   20 477   21
#>            profile predval
#> 1 407.49+0.27*time  407.49
#> 2 407.49+0.27*time  407.49
#> 3 407.49+0.27*time  407.49
#> 4 407.49+0.27*time  412.89
#> 5 407.49+0.27*time  412.89
#> 6 407.49+0.27*time  412.89

nrow(pred1)
#> [1] 19251

I don't know what's going on here. The ale() function expects there to be only one prediction for each row of input data. If you can help me map one prediction to each input row, then I can proceed with creating an appropriate custom prediction function.

snvv commented 3 months ago

Thank you very much for your assistance!

LongCART is designed for repeated measurements. In this particular example, each subject is measured three times, resulting in nrow(ACTG175) * 3 = nrow(pred1).

To align the row numbers of ACTG175 and pred1, the following code can be used:

library(dplyr)
pred2 = distinct(pred1, time, predval, pidnum, .keep_all = TRUE)

After applying this, nrow(ACTG175) will match nrow(pred2). To provide further clarity on this behavior, I have reached out to the developer for a detailed explanation.
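Following the earlier advice to use only base R inside prediction functions, the same deduplication can be written with duplicated(). A minimal sketch: `dedupe_preds()` is a hypothetical helper, and the mock data frame only imitates the layout of pred1 shown above (each row repeated three times).

```r
# Base R equivalent of
#   dplyr::distinct(pred1, time, predval, pidnum, .keep_all = TRUE):
# keep the first row for each unique (time, predval, pidnum) combination.
dedupe_preds <- function(pred, keys = c("time", "predval", "pidnum")) {
  pred[!duplicated(pred[keys]), , drop = FALSE]
}

# Mock prediction output in which every input row appears three times
# (values are made up for illustration)
pred_mock <- data.frame(
  pidnum  = rep(c(10056, 10056, 10059), each = 3),
  time    = rep(c(0, 20, 0), each = 3),
  predval = rep(c(407.49, 412.89, 380.10), each = 3)
)
pred_unique <- dedupe_preds(pred_mock)
nrow(pred_unique)  # 3
```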

Best regards, snvv

tripartio commented 3 months ago

OK, I can do that, with one important adjustment. I'll repeat a note I made above in passing that was easy to miss:

(One note of caution: the custom predict function is the slowest part of the entire ale algorithm, so try to write it as efficiently as possible. In particular, you should only use base R and avoid code from the tidyverse or other advanced packages within this custom function, or else it will run much slower.)

So, with that consideration, here goes for LongCART. First, based on the example from LongCART::predict.LongCART(), I train a model (it's rather slow, but again, I'm not going to mess with an example model that I don't really understand):

library(LongCART)
library(ale)
library(dplyr)

data(ACTG175)
gvars=c("gender", "wtkg", "hemo", "homo", "drugs",
        "karnof", "oprior", "z30", "zprior", "race",
        "str2", "symptom", "treat", "offtrt")
tgvars=c(0, 1, 0, 0, 0,
         1, 0, 0, 0, 0,
         0, 0, 0, 0)
out1<- LongCART(data=ACTG175, patid="pidnum", fixed=cd4~time,
                gvars=gvars, tgvars=tgvars, alpha=0.05,
                minsplit=100, minbucket=50, coef.digits=2)

Created on 2024-07-01 with reprex v2.1.0

The prediction functions for LongCART are quite slow, so I will calculate ALE on only a sample of 200 rows:

set.seed(0)
ACTG175_sample <- ACTG175 |>
  na.omit() |>
  slice_sample(n = 200)

Now for the custom predict function for compatibility with ale(). This is simpler than for LongituRF, but it still has a couple of unique issues:

pred_LongCART <- function(
    object, newdata,
    # pred_type is the patid argument to LongCART::predict.LongCART
    type = pred_type
  ) {
  patid <- type

  LongCART::predict.LongCART(object, newdata, patid) |>
    (`[`)(, c('time', 'predval', patid)) |>
    unique() |>
    (`[[`)('predval')
}
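Two properties of pred_LongCART() may be worth checking before relying on its output. First, unique() drops input rows that share the same time, prediction and patient, so the returned vector can be shorter than nrow(newdata); inside ale(), which overwrites the column being analyzed with shared bin values, such collisions seem plausible. Second, comparing head(ACTG175) with head(pred1) above suggests that predict.LongCART() returns rows sorted by patient rather than in the input order. A length mismatch of this kind is one possible, unverified, explanation for the split.default() warnings shown in the ale() output below. The sketch below is one alternative, not a tested LongCART integration: it assumes, based only on the printouts in this thread, that the prediction output repeats every input column unchanged alongside predval, and it matches each input row back to its prediction. `align_preds()` and `pred_LongCART_aligned()` are hypothetical names.

```r
# Hypothetical helper: return exactly one prediction per row of `newdata`, in
# newdata's row order, by matching each input row to a prediction row with
# identical values in all of newdata's columns.
# Assumption: `pred` repeats newdata's columns unchanged and adds `predval`.
align_preds <- function(pred, newdata) {
  cols <- names(newdata)
  row_key <- function(d) {
    # One string per row; unname() keeps column names away from paste()'s args
    do.call(paste, c(unname(as.list(d[cols])), sep = "\x1f"))
  }
  pred$predval[match(row_key(newdata), row_key(pred))]
}

# Alternative pred_fun built on align_preds(); `type` carries the patid
# column name, exactly as in pred_LongCART() above.
pred_LongCART_aligned <- function(object, newdata, type = pred_type) {
  patid <- type
  pred <- LongCART::predict.LongCART(object, newdata, patid)
  align_preds(pred, newdata)
}

# Mock check without LongCART: rows 1 and 2 share patient and time but differ
# in karnof; the mock "prediction output" is tripled and re-sorted.
newdata_mock <- data.frame(pidnum = c(1, 1, 2), time = c(0, 0, 20),
                           karnof = c(90, 100, 90))
pred_mock <- newdata_mock[rep(c(3, 1, 2), each = 3), ]
pred_mock$predval <- rep(c(30, 10, 20), each = 3)
aligned <- align_preds(pred_mock, newdata_mock)
aligned  # 10 20 30
```

Any NA in the result would mean the assumption about the prediction output did not hold for that row.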

Now we can call ale(). However, there is a tiny technical issue, probably due to a particularity of the ACTG175 dataset. Normally, the x_cols argument is not required, but it is in this case (see the comment in the code below for the explanation). Other than that, all the other arguments are required for the LongCART package.

Also, some odd warnings sometimes appear, which might be due to the dataset as well; I'm not sure. Anyway, the code runs and the ALE data is created, as the plots show.

ale_LongCART <- ale(
  ACTG175_sample,
  out1,
  # With the ACTG175 dataset, zprior causes a bug in ale() because all values are 1. Note that it cannot be removed from the dataset in the data argument because it was used to train the model. So, simply exclude it from calculations for ale().
  # Since x_cols is specified, also exclude cd4, since it is the y_col.
  x_cols = names(ACTG175_sample) |>
    setdiff(c('zprior', 'cd4')),
  y_col = 'cd4',
  pred_fun = pred_LongCART,
  # For the LongCART package, set pred_type as the patid argument to LongCART::predict.LongCART(). The pred_fun makes the link.
  pred_type = 'pidnum'
)
#> Warning in split.default(delta_pred, boot_ale_x_int): data length is not a
#> multiple of split variable
#> Warning in split.default(delta_pred, boot_ale_x_int): data length is not a
#> multiple of split variable

# Use the patchwork package for convenient plotting of multiple plots
ale_LongCART$plots |>
  patchwork::wrap_plots()
#> Warning: Removed 92 rows containing missing values or values outside the scale range
#> (`geom_line()`).
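Instead of hardcoding zprior, x_cols could be built by dropping the outcome column and any column with only one distinct value. A minimal base R sketch: `constant_cols()` and `make_x_cols()` are hypothetical helpers, and the toy data frame only imitates the relevant ACTG175 columns.

```r
# Hypothetical helper: names of columns with at most one distinct non-NA value,
# such as zprior in ACTG175.
constant_cols <- function(df) {
  is_const <- vapply(df, function(col) length(unique(col[!is.na(col)])) <= 1,
                     logical(1))
  names(df)[is_const]
}

# Build x_cols by dropping the outcome column and any constant columns
make_x_cols <- function(df, y_col) {
  setdiff(names(df), c(y_col, constant_cols(df)))
}

toy <- data.frame(cd4 = c(400, 350, 500), zprior = 1, wtkg = c(70, 82, 65))
make_x_cols(toy, y_col = "cd4")  # "wtkg"
```

With the sample above, the call would be `x_cols = make_x_cols(ACTG175_sample, y_col = 'cd4')`, which drops zprior along with any other column that happens to be constant in the sample.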

Hopefully, the code will work when adapted to your own dataset.

Finally, as I mentioned before, version 0.3.0 of the ale package does not support multi-value predictions such as those produced by LongCART::predict.SurvCART(), so I will not try to provide code for that yet. That should be supported in a version to be released later this year.

Please confirm if I have answered your questions.

snvv commented 3 months ago

I agree, this is a very complicated problem. The global variables (gvars) do not directly influence the response variable y (in this example, cd4). Instead, they are used to partition, or cluster, the input space. Then, in every terminal node (cluster), the accumulated effects of the fixed effects have to be estimated. See the example below, where I modified the example to use the wtkg variable as a fixed effect (explanatory variable).

Please note the function ProfilePlot that generates population-level longitudinal profile plots. Therefore, I think the only additional benefit of implementing ale is the confidence interval.

#--------------------------------------------------------------------- #
#   model: cd4~ time + wtkg                                       #
#--------------------------------------------------------------------- #

# Load the package and data, then perform an initial analysis
library(LongCART)
data("ACTG175")
ACTG175$time2 <- NULL
summary(ACTG175)
plot(ACTG175$wtkg)

# Define global and treatment variables
gvars <- c("gender", "hemo", "homo", "drugs",
           "karnof", "oprior", "z30", "zprior", "race",
           "str2", "symptom", "treat", "offtrt")
tgvars <- c(0, 0, 0, 0,
            1, 0, 0, 0, 0,
            0, 0, 0, 0)

# Fit the LongCART model
out3 <- LongCART(data = ACTG175, patid = "pidnum", fixed = cd4 ~ time + wtkg,
                 gvars = gvars, tgvars = tgvars, alpha = 0.05,
                 minsplit = 100, minbucket = 50, coef.digits = 2)

# Plot the model
par(mfrow = c(1, 1))
par(xpd = TRUE)
plot(out3, compress = TRUE)
text(out3, use.n = TRUE)

# Generate profile plots
ProfilePlot(x = out3, timevar = "wtkg", timevar.power = c(1)) 
ProfilePlot(x = out3, timevar = "time", timevar.power = c(1))

tripartio commented 3 months ago

"Please note the function ProfilePlot that generates population-level longitudinal profile plots. Therefore, I think the only additional benefit of implementing ale is the confidence interval."

Again, I don't know anything about these packages, but I doubt that what you wrote here is accurate, at least when there are interactions in the data.

On one hand, when there are no interactions between variables, most relationship plotting techniques give very similar results. On the other hand, ALE can be quite different from most techniques when there are relationships between variables.

I have no idea how LongCART::ProfilePlot() is implemented, but most techniques are "ceteris paribus" plots. That is, they assume "holding all other variables equal". For example, for a model that predicts Y from X1, X2, and X3, the relationship plot for Y against X1 might keep X2 and X3 at their median values. When there is no interaction between X1 and X2 or X1 and X3, this is fine and dandy. But if interactions exist, then holding X2 and X3 at their median values (or any other constant value) is unrealistic and results in a distorted plot that does not represent the original data.
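To make "holding all other variables equal" concrete, here is a minimal base R sketch of a ceteris paribus profile that fixes every other variable at its median. The model, the simulated data and the `median_profile()` helper are hypothetical: the toy model predicts X1 * X2, and X2 is simulated to track X1 closely, so the two variables are strongly related in the data.

```r
# Ceteris paribus profile: vary one variable over a grid while holding every
# other variable at its median, then predict.
median_profile <- function(pred_fun, df, x_col, grid) {
  base_row <- as.data.frame(lapply(df, median))        # one row of medians
  newdata <- base_row[rep(1, length(grid)), , drop = FALSE]
  newdata[[x_col]] <- grid
  pred_fun(newdata)
}

# Hypothetical toy model with an X1:X2 interaction
toy_model <- function(d) d$X1 * d$X2

set.seed(1)
x1 <- runif(500, -1, 1)
sim <- data.frame(X1 = x1, X2 = x1 + rnorm(500, sd = 0.1))  # X2 tracks X1
cp <- median_profile(toy_model, sim, "X1", grid = c(-1, 0, 1))
cp
```

Because X2 is held at its median, which is close to 0 here, the profile is a nearly flat straight line, even though along the data, where X2 is roughly equal to X1, the model's predictions follow a U shape close to X1 squared.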

ALE handles this issue by "holding all other variables at their natural values". That is, when it calculates the relationship between Y and X1, for each value of X1, instead of holding X2 and X3 equal to some constant value, it sets them to their natural values in the dataset at each value of X1. This results in plots that more accurately reflect the relationships as they actually occur in the data.

This is the basic theory underlying ALE as described in Apley, Daniel W., and Jingyu Zhu. "Visualizing the effects of predictor variables in black box supervised learning models." Journal of the Royal Statistical Society Series B: Statistical Methodology 82, no. 4 (2020): 1059-1086. You would need to compare different plot types with your own data to see if there is a difference.
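A minimal base R sketch of first-order ALE for one numeric variable, written only to illustrate the mechanism above; it is not the ale package's implementation. For the rows whose X1 falls in each interval, it computes the change in prediction when X1 moves across that interval while each row keeps its own observed X2, averages those local changes per interval, and accumulates them. Assumptions: `pred_fun` takes a data frame and returns one numeric prediction per row, intervals are placed at quantiles of the variable, and `ale_1d()`, `toy_model()` and the simulated data are hypothetical.

```r
# Minimal first-order ALE for one numeric variable (after Apley and Zhu, 2020)
ale_1d <- function(pred_fun, df, x_col, n_bins = 10) {
  x <- df[[x_col]]
  # Interval edges at quantiles so that each interval contains data
  edges <- unique(quantile(x, probs = seq(0, 1, length.out = n_bins + 1),
                           names = FALSE))
  n_int <- length(edges) - 1
  # Interval index for each row (1 .. n_int)
  bin <- findInterval(x, edges, rightmost.closed = TRUE, all.inside = TRUE)
  lo <- df
  lo[[x_col]] <- edges[bin]
  hi <- df
  hi[[x_col]] <- edges[bin + 1]
  # Local effect: prediction change across each row's own interval,
  # with all other variables left at their observed values
  delta <- pred_fun(hi) - pred_fun(lo)
  local <- tapply(delta, factor(bin, levels = seq_len(n_int)), mean)
  local[is.na(local)] <- 0
  effect <- c(0, cumsum(local))               # accumulate the local effects
  # Center so that the data-weighted mean effect is zero
  counts <- tabulate(bin, nbins = n_int)
  mids <- (effect[-1] + effect[-length(effect)]) / 2
  effect <- effect - sum(mids * counts) / sum(counts)
  data.frame(x = edges, ale = unname(effect))
}

# Same hypothetical setup as a ceteris paribus comparison: X2 tracks X1
toy_model <- function(d) d$X1 * d$X2
set.seed(1)
x1 <- runif(500, -1, 1)
sim <- data.frame(X1 = x1, X2 = x1 + rnorm(500, sd = 0.1))
ale_x1 <- ale_1d(toy_model, sim, "X1")
ale_x1
```

On this simulated data, the ALE curve for X1 comes out U-shaped rather than flat, because each local prediction change is computed with the rows' own X2 values, which rise with X1.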

Indeed, because of this fundamental advantage of ALE (among others), I hope to add a lot more extensions to ALE functionality beyond plots and confidence intervals. Its more accurate handling of interactions lays a foundation for many more advances.