Closed: ciaran-evans closed this pull request 10 months ago
This is awesome, thank you @ciaran-evans!

I tinkered with the compute_var_reduction function, and funnily enough I think we can eke out a smidge more speed by using low-level operations instead of arma::vec operations (although those are certainly clearer). I'm including a reprex below with a second function, compute_var_reduction_2, that aims to replicate the original function's answer with fewer operations. I believe it runs faster. Can you double-check that it's faster for both of us and not just me?
After checking the two functions, we can finish this PR by doing the following:

1. Move the compute_var_reduction function we want to stick with to the utility.cpp file, and include an entry in utility.h to match it.
2. Add a compute_var_reduction_exported function in orsf_oop.cpp that calls the compute_var_reduction function.
3. Update the unit tests so they run through compute_var_reduction_exported.
I would be happy to do these steps myself. Just let me know if you'd like to take the lead on them.
library(microbenchmark)
Rcpp::cppFunction(
depends = 'RcppArmadillo',
code = "
double compute_var_reduction(arma::vec& y_node,
                             arma::vec& w_node,
                             arma::uvec& g_node){
  arma::vec w_left = w_node % (1 - g_node);
  arma::vec w_right = w_node % g_node;
  double root_mean = sum(y_node % w_node) / sum(w_node);
  double left_mean = sum(y_node % w_left) / sum(w_left);
  double right_mean = sum(y_node % w_right) / sum(w_right);
  return (sum(w_node % pow(y_node - root_mean, 2)) -
          sum(w_left % pow(y_node - left_mean, 2)) -
          sum(w_right % pow(y_node - right_mean, 2))) / sum(w_node);
}
"
)
Rcpp::cppFunction(
depends = 'RcppArmadillo',
code = "
double compute_var_reduction_2(arma::vec& y_node,
                               arma::vec& w_node,
                               arma::uvec& g_node){
  double root_mean = 0, left_mean = 0, right_mean = 0;
  double root_w_sum = 0, left_w_sum = 0, right_w_sum = 0;
  for(arma::uword i = 0; i < y_node.n_rows; ++i){
    double w_i = w_node[i];
    double y_i = y_node[i] * w_i;
    root_w_sum += w_i;
    root_mean += y_i;
    if(g_node[i] == 1){
      right_w_sum += w_i;
      right_mean += y_i;
    } else {
      left_w_sum += w_i;
      left_mean += y_i;
    }
  }
  root_mean /= root_w_sum;
  left_mean /= left_w_sum;
  right_mean /= right_w_sum;
  double ans = 0;
  for(arma::uword i = 0; i < y_node.n_rows; ++i){
    double w_i = w_node[i];
    double y_i = y_node[i];
    ans += w_i * pow(y_i - root_mean, 2);
    if(g_node[i] == 1){
      ans -= w_i * pow(y_i - right_mean, 2);
    } else {
      ans -= w_i * pow(y_i - left_mean, 2);
    }
  }
  ans /= root_w_sum;
  return(ans);
}
"
)
n <- 10000
y <- rnorm(n)
w <- runif(n, 0, 2)
g <- rbinom(n, 1, 0.5)
microbenchmark(
v1 = compute_var_reduction(y, w, g),
v2 = compute_var_reduction_2(y, w, g),
times = 5000
)
#> Unit: microseconds
#> expr min lq mean median uq max neval cld
#> v1 65.001 80.601 111.15622 93.301 109.402 8312.402 5000 a
#> v2 72.400 74.302 84.17134 75.301 87.601 9133.501 5000 b
v1 <- compute_var_reduction(y, w, g)
v2 <- compute_var_reduction_2(y, w, g)
testthat::expect_true(abs(v1 - v2) < .Machine$double.eps)
Created on 2023-10-29 with reprex v2.0.2
One last thing. I think the second for loop in compute_var_reduction_2
could be faster. Do you want to take a stab at making a version 3 that runs faster than version 2?
Thanks so much @bcjaeger ! Using lower level operations does run faster on my machine too.
I have modified the second for loop, as you suggested, to avoid the if...else... statement. The modified version (v3 in the code below) does appear to run faster. I have just committed a revised version with the v3 code.
library(microbenchmark)
Rcpp::cppFunction(
depends = 'RcppArmadillo',
code = "
double compute_var_reduction(arma::vec& y_node,
                             arma::vec& w_node,
                             arma::uvec& g_node){
  arma::vec w_left = w_node % (1 - g_node);
  arma::vec w_right = w_node % g_node;
  double root_mean = sum(y_node % w_node) / sum(w_node);
  double left_mean = sum(y_node % w_left) / sum(w_left);
  double right_mean = sum(y_node % w_right) / sum(w_right);
  return (sum(w_node % pow(y_node - root_mean, 2)) -
          sum(w_left % pow(y_node - left_mean, 2)) -
          sum(w_right % pow(y_node - right_mean, 2))) / sum(w_node);
}
"
)
Rcpp::cppFunction(
depends = 'RcppArmadillo',
code = "
double compute_var_reduction_2(arma::vec& y_node,
                               arma::vec& w_node,
                               arma::uvec& g_node){
  double root_mean = 0, left_mean = 0, right_mean = 0;
  double root_w_sum = 0, left_w_sum = 0, right_w_sum = 0;
  for(arma::uword i = 0; i < y_node.n_rows; ++i){
    double w_i = w_node[i];
    double y_i = y_node[i] * w_i;
    root_w_sum += w_i;
    root_mean += y_i;
    if(g_node[i] == 1){
      right_w_sum += w_i;
      right_mean += y_i;
    } else {
      left_w_sum += w_i;
      left_mean += y_i;
    }
  }
  root_mean /= root_w_sum;
  left_mean /= left_w_sum;
  right_mean /= right_w_sum;
  double ans = 0;
  for(arma::uword i = 0; i < y_node.n_rows; ++i){
    double w_i = w_node[i];
    double y_i = y_node[i];
    ans += w_i * pow(y_i - root_mean, 2);
    if(g_node[i] == 1){
      ans -= w_i * pow(y_i - right_mean, 2);
    } else {
      ans -= w_i * pow(y_i - left_mean, 2);
    }
  }
  ans /= root_w_sum;
  return(ans);
}
"
)
Rcpp::cppFunction(
depends = 'RcppArmadillo',
code = "
double compute_var_reduction_3(arma::vec& y_node,
                               arma::vec& w_node,
                               arma::uvec& g_node){
  double root_mean = 0, left_mean = 0, right_mean = 0;
  double root_w_sum = 0, left_w_sum = 0, right_w_sum = 0;
  for(arma::uword i = 0; i < y_node.n_rows; ++i){
    double w_i = w_node[i];
    double y_i = y_node[i] * w_i;
    root_w_sum += w_i;
    root_mean += y_i;
    if(g_node[i] == 1){
      right_w_sum += w_i;
      right_mean += y_i;
    } else {
      left_w_sum += w_i;
      left_mean += y_i;
    }
  }
  root_mean /= root_w_sum;
  left_mean /= left_w_sum;
  right_mean /= right_w_sum;
  double ans = 0;
  for(arma::uword i = 0; i < y_node.n_rows; ++i){
    double w_i = w_node[i];
    double y_i = y_node[i];
    double g_i = g_node[i];
    double obs_mean = g_i * right_mean + (1 - g_i) * left_mean;
    ans += w_i * pow(y_i - root_mean, 2) - w_i * pow(y_i - obs_mean, 2);
  }
  ans /= root_w_sum;
  return(ans);
}
"
)
n <- 10000
y <- rnorm(n)
w <- runif(n, 0, 2)
g <- rbinom(n, 1, 0.5)
microbenchmark(
v1 = compute_var_reduction(y, w, g),
v2 = compute_var_reduction_2(y, w, g),
v3 = compute_var_reduction_3(y, w, g),
times = 5000
)
# Unit: microseconds
# expr min lq mean median uq max neval
# v1 80.334 82.376 106.76682 85.542 89.417 9912.667 5000
# v2 42.584 53.167 74.96333 56.459 60.959 10168.834 5000
# v3 34.500 43.167 63.38960 47.000 51.959 13631.459 5000
v1 <- compute_var_reduction(y, w, g)
v3 <- compute_var_reduction_3(y, w, g)
testthat::expect_true(abs(v1 - v3) < .Machine$double.eps)
Fantastic =] Awesome work. I am looking forward to reviewing!
I'm very happy with this! No changes to request on my end. This function will be called hundreds if not thousands of times for a single forest, so every bit of efficiency really counts. =] Thank you SO much. If you'd like to work together on this more, I should be focusing a lot more on regression trees soon.
Wonderful, thanks! I would love to contribute more; feel free to tag me on any issues where I would be useful.
This adds a compute_var_reduction function in orsf_oop.cpp to compute the variance reduction after a possible split, allowing for weights. The function is evaluated for correctness and speed in test-compute_var_reduction.R.

Given vectors $y$ (the response values), $w$ (the weights), and $g$ (the group assignments; 0 = left, 1 = right), the weighted reduction in variance is

$$\Delta = \frac{\sum_i w_i (y_i - \bar{y})^2 \;-\; \sum_{i : g_i = 0} w_i (y_i - \bar{y}_L)^2 \;-\; \sum_{i : g_i = 1} w_i (y_i - \bar{y}_R)^2}{\sum_i w_i}$$

where $\bar{y} = \sum_i w_i y_i \big/ \sum_i w_i$ is the overall weighted mean, and $\bar{y}_L$ and $\bar{y}_R$ are the weighted means of the left and right groups, respectively.