crflynn / skgrf

scikit-learn compatible Python bindings for grf (generalized random forests) C++ random forest library
https://skgrf.readthedocs.io/en/stable/
GNU General Public License v3.0

use grf forest pointer for predictions #40

Closed crflynn closed 3 years ago

crflynn commented 3 years ago

Currently skgrf runs predictions by passing the serialized forest dict to the prediction functions, which deserialize the forest on every call. That repeated deserialization is expensive. This PR modifies the prediction functions so that they accept pointers to already-deserialized forest objects.

In order to do this, we create a Cython GRFForest class which serves as a container for the forest pointer. We also add an _ensure_ptr method, which creates this forest wrapper for a trained estimator if it does not already exist.
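As a rough illustration of the lazy-wrapper idea, here is a pure-Python sketch. It is not the skgrf code: `FakeForest` stands in for the Cython `GRFForest` class, and the `serialized_forest_` and `_forest_ptr` attribute names are assumptions made for the example.

```python
class FakeForest:
    """Hypothetical stand-in for the Cython GRFForest pointer wrapper."""

    builds = 0  # counts how many times a forest was "deserialized"

    def __init__(self, serialized):
        FakeForest.builds += 1
        # A real wrapper would hold a C++ pointer; here we just copy the data.
        self.trees = list(serialized["trees"])


class SketchEstimator:
    def __init__(self, serialized):
        self.serialized_forest_ = serialized  # assumed name for the forest dict
        self._forest_ptr = None               # wrapper is built on first use

    def _ensure_ptr(self):
        # Build the wrapper only if it does not already exist.
        if self._forest_ptr is None:
            self._forest_ptr = FakeForest(self.serialized_forest_)
        return self._forest_ptr

    def predict(self, X):
        forest = self._ensure_ptr()
        # Toy "prediction": the mean of the per-tree constants for every row.
        value = sum(forest.trees) / len(forest.trees)
        return [value for _ in X]


est = SketchEstimator({"trees": [1.0, 2.0, 3.0]})
first = est.predict([[0], [1]])
second = est.predict([[2]])
```

Both `predict` calls share one `FakeForest` instance, so the deserialization cost is paid once instead of on every call.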

Since this wrapped pointer cannot be serialized, we also override the __getstate__ and __setstate__ methods to ensure we can pickle and unpickle each class. On unpickling, we also call _ensure_ptr so that the first call to predict after unpickling is not cold-started.
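The pickling hooks can be sketched along the same lines. This is an illustration under assumptions, not the PR's implementation: `ForestHandle`, `serialized_forest_`, and `_forest_ptr` are hypothetical names. The example calls `__getstate__` and `__setstate__` by hand, which mirrors what `pickle` does for an instance, and pickles only the state dict.

```python
import pickle


class ForestHandle:
    """Hypothetical stand-in for a pointer wrapper that cannot be pickled."""

    def __init__(self, serialized):
        self.trees = list(serialized["trees"])

    def __reduce__(self):
        raise TypeError("cannot pickle a raw forest pointer")


class PicklableEstimator:
    def __init__(self, serialized):
        self.serialized_forest_ = serialized  # assumed attribute name
        self._forest_ptr = None
        self._ensure_ptr()

    def _ensure_ptr(self):
        if self._forest_ptr is None:
            self._forest_ptr = ForestHandle(self.serialized_forest_)
        return self._forest_ptr

    def __getstate__(self):
        # Drop the wrapper so only plain, picklable data remains.
        state = self.__dict__.copy()
        state.pop("_forest_ptr", None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._forest_ptr = None
        # Rebuild eagerly so the first predict after unpickling is not cold.
        self._ensure_ptr()


est = PicklableEstimator({"trees": [1.0, 2.0]})

# Round-trip the state the way pickle would: getstate, bytes, setstate.
state = pickle.loads(pickle.dumps(est.__getstate__()))
restored = PicklableEstimator.__new__(PicklableEstimator)
restored.__setstate__(state)
```

Rebuilding inside `__setstate__` trades a slightly slower unpickle for a warm first prediction, which matches the intent described above.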

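As a result, predict calls are faster. In the simple comparison below, run on sample data while experimenting with this change, 100 pointer-based predictions took about 0.25 s versus about 1.26 s for the serialized path, roughly a 5x speedup: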

import timeit

from skgrf.ensemble import GRFRegressor
# Note: load_boston was removed in scikit-learn 1.2, so this needs an older release.
from sklearn.datasets import load_boston

boston_X, boston_y = load_boston(return_X_y=True)

gfr = GRFRegressor()
gfr.fit(boston_X, boston_y)

# ptr times the pointer-based path, ser times the serialized-dict path (100 calls each)
ptr = timeit.timeit(lambda: gfr.predict_ptr(boston_X), number=100)
ser = timeit.timeit(lambda: gfr.predict(boston_X), number=100)

print(ptr)  # 0.252420819000001
print(ser)  # 1.2568888119999997

Changes