tripartio / ale

Interpretable Machine Learning and Statistical Inference with Accumulated Local Effects (ALE)
https://tripartio.github.io/ale/
GNU General Public License v2.0

Custom function for randomForestSRC #8

Open snvv opened 4 days ago

snvv commented 4 days ago

Hello, I am encountering an issue with my custom function for randomForestSRC. Below is a representative example of the code I am working with:

library(randomForestSRC)
library(ale)

mod1 <- rfsrc(mpg ~ ., data = mtcars, importance = TRUE)
p <- predict(mod1, mtcars)
> p
Sample size of test (predict) data: 32
Number of trees grown: 500
Average number of terminal nodes: 3.462
Total number of variables used: 10
Resampling method: swor
Resample size used to grow trees: 20
Analysis: RF-R
Family: regr
R-squared: 0.87741596
Requested performance error: 4.45275514

> p$predicted
 [1] 20.94204 20.87830 24.36764 19.99237 16.63315 19.44042 14.97416 23.90974 23.38696
[10] 19.18424 19.20663 15.69498 15.80511 15.75969 13.68010 13.64524 14.46933 27.40950
[19] 28.03295 28.02586 23.73823 16.04222 16.65875 15.07151 16.25606 27.73644 26.02305
[28] 26.42031 16.17787 20.04792 15.22690 23.39022

My custom prediction function is as follows:

pred_fun <- function(m, x, type = pred_type) {
  as.numeric(predict(m, newdata=x)$predicted)
}

rf_ale <- ale(
  mtcars, mod1,
  y_col = "mpg",
  pred_fun = pred_fun,
  pred_type = 'raw',
  parallel = 14  
)

However, I am receiving the following error:

Error in value[[3L]](cond): 
  There is an issue with the `predict` function `pred_fun` or with the prediction type `pred_type`. Please refer to the help documentation for `ale` on how to create a custom predict function for non-standard models. 

Full error message: 
Error in pred_fun(object = model, newdata = data, type = pred_type): unused arguments (object = model, newdata = data)

I have attempted several variations but have not been able to resolve the issue.

Any assistance would be greatly appreciated.

Thank you!

tripartio commented 4 days ago

The custom predict function must have the signature function(object, newdata, type), because ale calls it with exactly those argument names, as the full error message shows: pred_fun(object = model, newdata = data, type = pred_type). These names follow the R convention for predict methods. Because ale relies on that convention, it works automatically with many model types; only those (like randomForestSRC) that don't conform need a custom predict function to bring them into conformity.

So, the corrected custom predict function is simply:

pred_fun <- function(object, newdata, type = pred_type) {
  as.numeric(predict(object, newdata)$predicted)
}
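
To see why the argument names matter, here is a minimal base-R sketch of the mismatch. It assumes only what the error message above shows, namely that the function is called with the named arguments object, newdata and type. The lm model is a stand-in for the rfsrc model, and call_like_ale is a hypothetical helper for illustration, not part of the ale package.

```r
# Stand-in model (assumption: lm replaces rfsrc so the sketch needs only base R)
fit <- lm(mpg ~ wt + hp, data = mtcars)

# Argument names from the original post: m and x
bad_pred_fun <- function(m, x, type = "response") {
  as.numeric(predict(m, newdata = x))
}

# Argument names required by the documentation: object and newdata
good_pred_fun <- function(object, newdata, type = "response") {
  as.numeric(predict(object, newdata = newdata))
}

# Hypothetical helper: calls f with named arguments, as in the error message,
# and returns the error text instead of stopping if the call fails
call_like_ale <- function(f) {
  tryCatch(
    f(object = fit, newdata = mtcars, type = "response"),
    error = function(e) conditionMessage(e)
  )
}

bad_result <- call_like_ale(bad_pred_fun)    # error text: unused arguments
good_result <- call_like_ale(good_pred_fun)  # numeric vector, one value per row

print(bad_result)
print(head(good_result))
```

R matches named arguments against the function's formal argument names, so object = ... and newdata = ... cannot be bound to m and x, and the call fails before the function body runs.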

In fact, the error message pointed to this: "Please refer to the help documentation for `ale` on how to create a custom predict function for non-standard models." If you consult help(ale), you will find:

The requirements for this custom function are:
* It must take three required arguments and nothing else:
    * object: a model
    * newdata: a dataframe or compatible table type
    * type: a string; it should usually be specified as type = pred_type 

These argument names are according to the R convention for the generic stats::predict function.
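
The requirements above can be checked before calling ale(). The following is a rough sketch, not part of the ale package: check_pred_fun is a hypothetical helper, the lm model is a base-R stand-in for the rfsrc model, and the expectation that the function returns one numeric prediction per row is an assumption based on the as.numeric() call in the custom function.

```r
# Hypothetical helper (not from the ale package): checks that a custom predict
# function takes exactly the documented arguments and returns usable predictions
check_pred_fun <- function(pred_fun, model, data) {
  arg_names <- names(formals(pred_fun))
  if (!setequal(arg_names, c("object", "newdata", "type"))) {
    stop(
      "pred_fun has arguments (", paste(arg_names, collapse = ", "),
      "); expected (object, newdata, type)"
    )
  }
  # Call with named arguments, as in the error message from the original post
  preds <- pred_fun(object = model, newdata = data, type = "response")
  # Assumption: one numeric prediction per row of the data
  if (!is.numeric(preds) || length(preds) != nrow(data)) {
    stop("pred_fun should return a numeric vector with one value per row")
  }
  invisible(TRUE)
}

# Stand-in model and predict function (assumption: lm replaces rfsrc here)
fit <- lm(mpg ~ wt + hp, data = mtcars)
my_pred_fun <- function(object, newdata, type = "response") {
  as.numeric(predict(object, newdata = newdata))
}

ok <- check_pred_fun(my_pred_fun, fit, mtcars)
print(ok)
```

With the randomForestSRC model from this issue, the same check would be run as check_pred_fun(pred_fun, mod1, mtcars) after giving type a default that does not depend on pred_type.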

One small note: the pred_type = 'raw' in the ale() call is unnecessary. It does nothing here, because the custom pred_fun never uses its type argument.

Here is the full corrected code:

library(randomForestSRC)
library(ale)

mod1 <- rfsrc(mpg ~ ., data = mtcars, importance = TRUE)

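# Custom predict function with the (object, newdata, type) signature that ale requires.
# rfsrc predictions are stored in the $predicted element of the returned object.
pred_fun <- function(object, newdata, type = pred_type) {
  as.numeric(predict(object, newdata)$predicted)
}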

rf_ale <- ale(
  mtcars, mod1,
  y_col = "mpg",
  pred_fun = pred_fun,
  parallel = 14  
)

rf_ale$plots$wt

Created on 2024-09-25 with reprex v2.1.1