grf-labs / grf

Generalized Random Forests
https://grf-labs.github.io/grf/
GNU General Public License v3.0

How to calculate the R2 of a local linear forest? #1400

Closed nikosGeography closed 4 months ago

nikosGeography commented 4 months ago

Great package! I recently discovered the local linear forest (LLF) and I am impressed. Keep up the good work. I'd like to ask how I can compute the R2 (R-squared) of an LLF regression model, given the example below:

library(grf)

p <- 20
n <- 1000
sigma <- sqrt(20)

mu <- function(x){ log(1 + exp(6 * x)) }
X <- matrix(runif(n * p, -1, 1), nrow = n)
Y <- mu(X[,1]) + sigma * rnorm(n)

X.test <- matrix(runif(n * p, -1, 1), nrow = n)
ticks <- seq(-1, 1, length = n)
X.test[,1] <- ticks
truth <- mu(ticks)

ll.forest <- ll_regression_forest(X, Y, enable.ll.split = TRUE)
preds.llf <- predict(ll.forest, X.test,
                     linear.correction.variables = 1)$predictions

How can I calculate the R2? If this question has already been answered, feel free to close it, but please point me to the answer. I searched the closed issues, but either it hasn't been answered yet or I missed it.

R 4.3.2, RStudio 2023.12.1 Build 402, Windows 11.

erikcs commented 4 months ago

Hi @nikosGeography, from your output above you have responses $Y_i$ and predictions $\hat Y_i$ (preds.llf) that you can plug into the R-squared formula (or any other favorite error metric). If, further down the line, you are doing causal effect estimation and want an area-under-the-curve metric, you could check out the RATE (rank-weighted average treatment effect).
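As a minimal sketch of the plug-in calculation described above, the base R helper below computes R-squared as $1 - \sum_i (Y_i - \hat Y_i)^2 / \sum_i (Y_i - \bar Y)^2$. The name `r_squared` and the toy vectors are made up for illustration and are not part of grf. One caveat for the example in the question: preds.llf holds predictions at X.test, so it is paired row by row with truth rather than with Y.

```r
# Hypothetical helper (not part of grf): R-squared as 1 - SS_res / SS_tot.
r_squared <- function(y, y_hat) {
  stopifnot(length(y) == length(y_hat))
  ss_res <- sum((y - y_hat)^2)     # residual sum of squares
  ss_tot <- sum((y - mean(y))^2)   # total sum of squares around the mean
  1 - ss_res / ss_tot
}

# Toy check with made-up numbers (no forest needed).
y      <- c(1, 2, 3, 4, 5)
y_hat  <- c(1.1, 1.9, 3.2, 3.8, 5.0)
r2_toy <- r_squared(y, y_hat)      # SS_res = 0.1, SS_tot = 10

# With the objects from the question, one possible pairing (an assumption,
# not something stated in the thread) would be:
#   r_squared(truth, preds.llf)                    # test grid vs. noise-free signal
#   r_squared(Y, predict(ll.forest)$predictions)   # training Y vs. out-of-bag predictions
```

The two commented calls measure different things: the first compares against the noise-free function mu, while the second compares against noisy responses, so it is bounded well below 1 when sigma is large.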

nikosGeography commented 4 months ago

I found a solution that computes the correlation between the predictions and the target variable and squares the result. Something like r_squared <- cor(preds.llf, target_variable)^2.
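The squared correlation is not always the same number as R-squared computed from sums of squares. The two agree for in-sample fits of a linear model with an intercept, but the squared correlation ignores any bias or rescaling in the predictions. The toy example below, with made-up numbers, sketches the difference:

```r
# Made-up data: predictions that are perfectly correlated with y
# but badly calibrated (wrong scale and offset).
y     <- c(1, 2, 3, 4, 5)
y_hat <- 2 * y + 3

# Squared Pearson correlation, as in the comment above.
r2_cor <- cor(y_hat, y)^2                                   # equals 1

# R-squared from sums of squares: 1 - SS_res / SS_tot.
r2_ss <- 1 - sum((y - y_hat)^2) / sum((y - mean(y))^2)      # equals -18
```

Which of the two is appropriate depends on whether the goal is to measure how well predictions track the response up to a linear transformation (squared correlation) or how close they are in absolute terms (sums of squares).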