timmens / causal-forest

Implements the Causal Forest algorithm formulated in Athey and Wager (2018).
MIT License

Optimize _find_optimal_split function. #6

Open timmens opened 4 years ago

timmens commented 4 years ago

Problem: Right now the function _find_optimal_split is very inefficient. In the inner loop over splitting_points I recompute means and sums from scratch in every iteration, even though I could update the values from the previous iteration.

Solution: Implement a dynamic updating algorithm that finds the best splitting point for a given feature index.
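A minimal sketch of what such a dynamic update could look like. It is not the code from this repository: the function name best_split_running_sums is made up for illustration, and a plain sum-of-squared-errors criterion on an outcome y stands in for whatever criterion _find_optimal_split actually uses. The point is only that the left and right sums are updated with one addition per candidate split instead of being recomputed.

```python
import numpy as np


def best_split_running_sums(x, y):
    """Hypothetical helper: find the split on feature x that minimizes the
    summed squared error of y in the two children, using running sums.

    Returns (threshold, loss); threshold is None if x has no two distinct values.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    # Sort once, so that every candidate split is a prefix/suffix partition.
    order = np.argsort(x, kind="stable")
    x_sorted, y_sorted = x[order], y[order]
    n = len(y_sorted)

    # Totals are computed a single time, outside the loop.
    total_sum = y_sorted.sum()
    total_sq = (y_sorted ** 2).sum()

    left_sum = 0.0
    left_sq = 0.0
    best_loss, best_threshold = np.inf, None

    for i in range(n - 1):
        # Dynamic update: move observation i from the right child to the left.
        left_sum += y_sorted[i]
        left_sq += y_sorted[i] ** 2

        # Equal feature values cannot be separated by a threshold.
        if x_sorted[i] == x_sorted[i + 1]:
            continue

        n_left = i + 1
        n_right = n - n_left
        right_sum = total_sum - left_sum
        right_sq = total_sq - left_sq

        # SSE of a child equals sum(y^2) - (sum(y))^2 / n.
        loss = (left_sq - left_sum ** 2 / n_left) + (
            right_sq - right_sum ** 2 / n_right
        )
        if loss < best_loss:
            best_loss = loss
            best_threshold = (x_sorted[i] + x_sorted[i + 1]) / 2

    return best_threshold, best_loss
```

After the initial sort, the loop does constant work per candidate split, so the search for one feature costs O(n log n) instead of O(n^2). The same pattern should carry over to a treatment-effect criterion by keeping separate running sums and counts for treated and control observations.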

timmens commented 4 years ago

What has been done: The commits (aca31b1ea3f76964) and (a2f504f9ccfbab8cd9) speed up the inner loop (over observations) by a big margin. In the first commit I replaced most np.sum() and np.mean() calls with dynamically updated sums. In the second commit I swapped the pd.DataFrame data storage for the faster np.array and now simply convert the end result to a pd.DataFrame.

What still needs to be done:

  1. The code needs to be checked for correctness against the old implementation, and unit tests have to be written.
  2. To make the code even faster it has to be profiled while numba is disabled, since this makes it possible to see which function calls make _find_optimal_split slow. Current profiling has shown that the function _find_optimal_split is still the only major concern.