ropensci / aorsf

Accelerated Oblique Random Survival Forests
https://docs.ropensci.org/aorsf

Resolves ropensci/aorsf#29 #30

Closed: ciaran-evans closed this pull request 10 months ago

ciaran-evans commented 10 months ago

This adds a compute_var_reduction function in orsf_oop.cpp to compute the variance reduction after a possible split, allowing for weights. The function is evaluated for correctness and speed in test-compute_var_reduction.R.

Given vectors $y$ (the response values), $w$ (the weights), and $g$ (the group assignments -- 0 = left, 1 = right), the weighted reduction in variance is

$$\Delta = \frac{1}{\sum\limits_{i=1}^n w_i} \left(\sum\limits_{i=1}^n w_i (y_i - \overline{y}_{node})^2 - \sum\limits_{i=1}^n w_i (1 - g_i)(y_i - \overline{y}_{left})^2 - \sum\limits_{i=1}^n w_i g_i (y_i - \overline{y}_{right})^2 \right)$$

where

$$\overline{y}_{node} = \frac{\sum\limits_{i=1}^n w_i y_i}{\sum\limits_{i=1}^n w_i} \hspace{1cm} \overline{y}_{left} = \frac{\sum\limits_{i=1}^n w_i (1-g_i) y_i}{\sum\limits_{i=1}^n w_i (1-g_i)} \hspace{1cm} \overline{y}_{right} = \frac{\sum\limits_{i=1}^n w_i g_i y_i}{\sum\limits_{i=1}^n w_i g_i}$$

bcjaeger commented 10 months ago

This is awesome, thank you @ciaran-evans!

I tinkered with the compute_var_reduction function, and funnily enough I think we can eke out a smidge more speed by using low-level operations instead of arma::vec operations (although those are certainly clearer). I'm including a reprex below with a second function, compute_var_reduction_2, that aims to replicate the original function's answer with fewer operations. I believe this runs faster. Can you double-check that it's faster for both of us and not just me?

After checking the two functions, we can finish this PR by doing the following:

  1. move whichever compute_var_reduction we want to stick with to the utility.cpp file, and include an entry in utility.h to match it.
  2. create a compute_var_reduction_exported function in orsf_oop.cpp that calls the compute_var_reduction function
  3. modify the testing file to use compute_var_reduction_exported

I would be happy to do these steps myself. Just let me know if you'd like to take the lead on them.

library(microbenchmark)

Rcpp::cppFunction(
 depends = 'RcppArmadillo',
 code = "
 double compute_var_reduction(arma::vec& y_node,
                             arma::vec& w_node,
                             arma::uvec& g_node){
 arma::vec w_left = w_node % (1 - g_node);
 arma::vec w_right = w_node % g_node;
 double root_mean = sum(y_node % w_node)/sum(w_node);
 double left_mean = sum(y_node % w_left)/sum(w_left);
 double right_mean = sum(y_node % w_right)/sum(w_right);

 return (sum(w_node % pow(y_node - root_mean, 2)) - sum(w_left % pow(y_node - left_mean, 2)) -
          sum(w_right % pow(y_node - right_mean, 2)))/sum(w_node);
}
 "
)

Rcpp::cppFunction(
 depends = 'RcppArmadillo',
 code = "
 double compute_var_reduction_2(arma::vec& y_node,
                               arma::vec& w_node,
                               arma::uvec& g_node){

 double root_mean = 0, left_mean = 0, right_mean = 0;
 double root_w_sum = 0, left_w_sum = 0, right_w_sum = 0;

 for(arma::uword i = 0; i < y_node.n_rows; ++i){

  double w_i = w_node[i];
  double y_i = y_node[i] * w_i;

  root_w_sum     += w_i;
  root_mean += y_i;

  if(g_node[i] == 1){
   right_w_sum += w_i;
   right_mean  += y_i;
  } else {
   left_w_sum += w_i;
   left_mean  += y_i;
  }

 }

 root_mean /= root_w_sum;
 left_mean /= left_w_sum;
 right_mean /= right_w_sum;

 double ans = 0;

 for(arma::uword i = 0; i < y_node.n_rows; ++i){

  double w_i = w_node[i];
  double y_i = y_node[i];

  ans += w_i * pow(y_i - root_mean, 2);

  if(g_node[i] == 1){
   ans -= w_i * pow(y_i - right_mean, 2);
  } else {
   ans -= w_i * pow(y_i - left_mean, 2);
  }

 }

 ans /= root_w_sum;

 return(ans);

}
 "
)

n <- 10000

y <- rnorm(n)
w <- runif(n, 0, 2)
g <- rbinom(n, 1, 0.5)

microbenchmark(
 v1 = compute_var_reduction(y, w, g),
 v2 = compute_var_reduction_2(y, w, g),
 times = 5000
)
#> Unit: microseconds
#>  expr    min     lq      mean median      uq      max neval cld
#>    v1 65.001 80.601 111.15622 93.301 109.402 8312.402  5000  a 
#>    v2 72.400 74.302  84.17134 75.301  87.601 9133.501  5000   b

v1 = compute_var_reduction(y, w, g)
v2 = compute_var_reduction_2(y, w, g)

testthat::expect_true(abs(v1 - v2) < .Machine$double.eps)

Created on 2023-10-29 with reprex v2.0.2

bcjaeger commented 10 months ago

One last thing. I think the second for loop in compute_var_reduction_2 could be faster. Do you want to take a stab at making a version 3 that runs faster than version 2?

ciaran-evans commented 10 months ago

Thanks so much @bcjaeger! Using lower-level operations does run faster on my machine too.

I have modified the second for loop, as you suggested, to avoid the if...else... statement. The modified version (v3 in the code below) does appear to run faster, and I have just committed a revised version with the v3 code.

library(microbenchmark)

Rcpp::cppFunction(
  depends = 'RcppArmadillo',
  code = "
 double compute_var_reduction(arma::vec& y_node,
                             arma::vec& w_node,
                             arma::uvec& g_node){
 arma::vec w_left = w_node % (1 - g_node);
 arma::vec w_right = w_node % g_node;
 double root_mean = sum(y_node % w_node)/sum(w_node);
 double left_mean = sum(y_node % w_left)/sum(w_left);
 double right_mean = sum(y_node % w_right)/sum(w_right);

 return (sum(w_node % pow(y_node - root_mean, 2)) - sum(w_left % pow(y_node - left_mean, 2)) -
          sum(w_right % pow(y_node - right_mean, 2)))/sum(w_node);
}
 "
)

Rcpp::cppFunction(
  depends = 'RcppArmadillo',
  code = "
 double compute_var_reduction_2(arma::vec& y_node,
                               arma::vec& w_node,
                               arma::uvec& g_node){

 double root_mean = 0, left_mean = 0, right_mean = 0;
 double root_w_sum = 0, left_w_sum = 0, right_w_sum = 0;

 for(arma::uword i = 0; i < y_node.n_rows; ++i){

  double w_i = w_node[i];
  double y_i = y_node[i] * w_i;

  root_w_sum     += w_i;
  root_mean += y_i;

  if(g_node[i] == 1){
   right_w_sum += w_i;
   right_mean  += y_i;
  } else {
   left_w_sum += w_i;
   left_mean  += y_i;
  }

 }

 root_mean /= root_w_sum;
 left_mean /= left_w_sum;
 right_mean /= right_w_sum;

 double ans = 0;

 for(arma::uword i = 0; i < y_node.n_rows; ++i){

  double w_i = w_node[i];
  double y_i = y_node[i];

  ans += w_i * pow(y_i - root_mean, 2);

  if(g_node[i] == 1){
   ans -= w_i * pow(y_i - right_mean, 2);
  } else {
   ans -= w_i * pow(y_i - left_mean, 2);
  }

 }

 ans /= root_w_sum;

 return(ans);

}
 "
)

Rcpp::cppFunction(
  depends = 'RcppArmadillo',
  code = "
 double compute_var_reduction_3(arma::vec& y_node,
                                arma::vec& w_node,
                                arma::uvec& g_node){

   double root_mean = 0, left_mean = 0, right_mean = 0;
   double root_w_sum = 0, left_w_sum = 0, right_w_sum = 0;

   for(arma::uword i = 0; i < y_node.n_rows; ++i){

     double w_i = w_node[i];
     double y_i = y_node[i] * w_i;

     root_w_sum     += w_i;
     root_mean += y_i;

     if(g_node[i] == 1){
       right_w_sum += w_i;
       right_mean  += y_i;
     } else {
       left_w_sum += w_i;
       left_mean  += y_i;
     }

   }

   root_mean /= root_w_sum;
   left_mean /= left_w_sum;
   right_mean /= right_w_sum;

   double ans = 0;

   for(arma::uword i = 0; i < y_node.n_rows; ++i){

     double w_i = w_node[i];
     double y_i = y_node[i];
     double g_i = g_node[i];
     double obs_mean = g_i*right_mean + (1 - g_i)*left_mean;

     ans += w_i * pow(y_i - root_mean, 2) - w_i * pow(y_i - obs_mean, 2);

   }
   ans /= root_w_sum;
   return(ans);
 }
 "
)

n <- 10000

y <- rnorm(n)
w <- runif(n, 0, 2)
g <- rbinom(n, 1, 0.5)

microbenchmark(
  v1 = compute_var_reduction(y, w, g),
  v2 = compute_var_reduction_2(y, w, g),
  v3 = compute_var_reduction_3(y, w, g),
  times = 5000
)
# Unit: microseconds
# expr    min     lq      mean median     uq       max neval
# v1 80.334 82.376 106.76682 85.542 89.417  9912.667  5000
# v2 42.584 53.167  74.96333 56.459 60.959 10168.834  5000
# v3 34.500 43.167  63.38960 47.000 51.959 13631.459  5000

v1 = compute_var_reduction(y, w, g)
v3 = compute_var_reduction_3(y, w, g)

testthat::expect_true(abs(v1 - v3) < .Machine$double.eps)
bcjaeger commented 10 months ago

Fantastic =] Awesome work. I am looking forward to reviewing!

bcjaeger commented 10 months ago

I'm very happy with this! No changes to request on my end. This function will be called hundreds if not thousands of times for a single forest so every bit of efficiency really counts. =] Thank you SO much. If you'd like to work together on this more, I should be focusing a lot more on regression trees soon.

ciaran-evans commented 10 months ago

Wonderful, thanks! I would love to contribute more; feel free to tag me on any issues where I would be useful.