tripartio / ale

Interpretable Machine Learning and Statistical Inference with Accumulated Local Effects (ALE)
https://tripartio.github.io/ale/
GNU General Public License v2.0

`pred_fun` function for `randomForest` and `xrf` #5

Closed: snvv closed this issue 1 month ago

snvv commented 2 months ago

Hello, I'm encountering issues with the predict function in several models. Below are the examples where I am facing difficulties:

# Example 1: Random Forest model
library("randomForest")
data("Boston", package = "MASS")
rf <- randomForest(medv ~ ., data = Boston, ntree = 10)

# Example 2: xRF model
library(xrf)
m <- xrf(Petal.Length ~ ., iris,
         xgb_control = list(nrounds = 2, max_depth = 2),
         family = 'gaussian')

I've tried various prediction functions (pred_fun), but all attempts have failed.

Could you please provide an example of how to correctly use the predict function with these models?

Thank you!

tripartio commented 2 months ago

Hello,

Perhaps part of the challenge you are facing here is that the ale package tries to automatically detect predict functions as much as possible. In fact, you do not need custom predict functions for basic operation in either of these cases.

I plan to eventually write a vignette on how to write custom predict functions (if they are needed). Maybe I could use this response as a first draft.

The first thing to do is to just try to run ale(). It often works automatically.

library(ale)

# Example 1: Random Forest model
library("randomForest")
#> randomForest 4.7-1.1
#> Type rfNews() to see new features/changes/bug fixes.
data("Boston", package = "MASS")
rf <- randomForest(medv ~ ., data = Boston, ntree = 10)

rf_ale <- ale(Boston, rf)
# sample plot
rf_ale$plots$lstat

Created on 2024-09-02 with reprex v2.1.1

So, in the case of randomForest, the default options work just fine. No custom predict function or anything else is needed. This is often the case. Now let's try the same approach with xrf:

library(ale)
library(xrf)
m <- xrf(Petal.Length ~ ., iris,
         xgb_control = list(nrounds = 2, max_depth = 2),
         family = 'gaussian')

xrf_ale <- ale(iris, m)
#> Error in validate_y_col(y_col = y_col, data = data, model = model): This model seems to be non-standard, so y_col must be provided.

Created on 2024-09-02 with reprex v2.1.1

But that doesn't work with xrf out of the box. It is a less common model type, so ale() cannot automatically determine the outcome column. As the error message indicates, let's provide it with the y_col argument:

# Provide the y_col
xrf_ale <- ale(iris, m, y_col = 'Petal.Length')

# sample plot
xrf_ale$plots$Species

Created on 2024-09-02 with reprex v2.1.1

These were rather easy cases where the pred_fun argument is not needed at all. The lesson here is to try your model with the default options first; ale() often works.

I will alert you, though, that the most common complication is when a model's predict function requires a type argument to specify what kind of prediction you want. In that case, enter it as the pred_type argument in ale().
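As a rough illustration of the pred_type point (a sketch, not part of the original exchange): the logistic regression on mtcars below is chosen purely as an example of a model whose predict() method takes a type argument. The guarded call at the end assumes an installed version of ale that exports ale() with a pred_type argument, as used in this thread.

```r
# Illustrative data and model (assumption: not from the original thread).
cars_data <- mtcars[, c("am", "wt", "hp")]
logit_model <- glm(am ~ wt + hp, data = cars_data, family = binomial())

# The same model gives different predictions depending on `type`:
# "link" returns log-odds, "response" returns probabilities in [0, 1].
link_pred <- predict(logit_model, newdata = cars_data, type = "link")
prob_pred <- predict(logit_model, newdata = cars_data, type = "response")

# Tell ale() which kind of prediction to use via pred_type.
# Runs only if a version of ale that exports ale() is installed.
has_ale <- requireNamespace("ale", quietly = TRUE) &&
  "ale" %in% getNamespaceExports("ale")
if (has_ale) {
  logit_ale <- ale::ale(cars_data, logit_model, pred_type = "response")
}
```

Whatever value predict() accepts for type in your model's package is what would go in pred_type; check that package's predict documentation for the valid options.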

If these defaults and simple arguments don't work for you, then post another comment on this same issue (if it is with xrf or randomForest) or raise a completely new issue if it is with a different package.
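For the rarer case where neither the defaults nor pred_type is enough, a custom pred_fun might look roughly like the sketch below. It rests on two assumptions: that ale() calls pred_fun with the arguments object, newdata, and type, and that it expects back a plain numeric vector with one prediction per row. The name numeric_pred_fun and the lm model on mtcars are hypothetical, for illustration only.

```r
# Hypothetical custom prediction wrapper (the name is illustrative).
# Assumption: ale() passes `object`, `newdata`, and `type` to pred_fun
# and expects one numeric prediction per row of `newdata`.
numeric_pred_fun <- function(object, newdata, type = "response") {
  preds <- stats::predict(object, newdata = newdata, type = type)
  as.numeric(preds)  # strip names so a plain numeric vector is returned
}

# Illustrative model (assumption: not from the original thread).
cars_data <- mtcars[, c("mpg", "wt", "hp")]
lm_model <- lm(mpg ~ wt + hp, data = cars_data)

# Check the wrapper on its own before handing it to ale().
custom_preds <- numeric_pred_fun(lm_model, cars_data)

# Runs only if a version of ale that exports ale() is installed.
has_ale <- requireNamespace("ale", quietly = TRUE) &&
  "ale" %in% getNamespaceExports("ale")
if (has_ale) {
  lm_ale <- ale::ale(cars_data, lm_model, pred_fun = numeric_pred_fun)
}
```

Testing the wrapper directly on the training data first, as above, makes it easier to tell whether a later failure comes from the prediction function or from ale() itself.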

snvv commented 2 months ago

Thank you very much for your help so far.

I have encountered another issue (hopefully the last one) related to interactions. When I run the following code:

iter <- ale_ixn(
  Boston, rf,
  parallel = 2
)

I receive the error:

  2% | Calculating ALE interactions ETA: 26s                       
Progress interrupted by purrr_error_indexed condition: ℹ In index: 1.
Caused by error in `map()`:
ℹ In index: 3.
Caused by error in `calc_ale_ixn()`:
! x2 must be numeric or integer. Only x1 can be of a different datatype.
Error in (function (.x, .f, ..., .progress = FALSE) : ℹ In index: 1.                               
Caused by error in `map()`:
ℹ In index: 3.
Caused by error in `calc_ale_ixn()`:
! x2 must be numeric or integer. Only x1 can be of a different datatype.

I receive the same error when using the xrf model trained on the Boston dataset.

Could you please assist me in resolving this issue?

Thank you!

tripartio commented 2 months ago

@snvv To keep things focused, could you please post this different issue as a new issue? I will close this one since it seems to be resolved.

And it doesn't need to be your last issue. :-)