crflynn / skranger

scikit-learn compatible Python bindings for ranger C++ random forest library
https://skranger.readthedocs.io/en/stable/
GNU General Public License v3.0

'sample_fraction' runtime error with one prediction #61

Closed dalekube closed 3 years ago

dalekube commented 3 years ago

Setting 'sample_fraction' when fitting a RangerForestRegressor causes a crash when the trained model is later used to predict a single observation with rfr.predict(). The process aborts with: terminate called after throwing an instance of 'std::runtime_error' what(): sample_fraction too small, no observations sampled.

The error does not occur when more than one observation is passed to rfr.predict(). It also does not occur when the model is fitted with sample_fraction=[1] and then predicts a single observation.

crflynn commented 3 years ago

I believe sample fraction is only used on fitting, so we can just pass sample_fraction=[1] on predict. I'll take a look at this.

dalekube commented 3 years ago

> I believe sample fraction is only used on fitting, so we can just pass sample_fraction=[1] on predict. I'll take a look at this.

I do not specify the sample_fraction in the predict call. It looks like predict() inherits the parameter from fit() even though it's not relevant for predictions.

crflynn commented 3 years ago

Right, I should have been more specific. The entrypoint into the ranger bindings is the same for fit/predict, and right now we pass the specified sample_fraction for predicting too: https://github.com/crflynn/skranger/blob/d409bd8fff9260fe5a30f9837965b5fd05418250/skranger/ensemble/ranger_forest_regressor.py#L357

When creating the forest, the C++ code runs a check that fails in this case: https://github.com/imbs-hl/ranger/blob/e8b05f47892bb4968c4e6057f68b35bcd0b52225/src/Forest.cpp#L256. The check runs even on predict, where it isn't needed.
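
To make the failure mode concrete, here is a small pure-Python model of that check. It assumes the C++ condition amounts to "number of rows times the first sample_fraction entry is less than 1", which is my reading of the linked line rather than a copy of the ranger source. The function names check_sample_fraction and passes are hypothetical and exist only for this illustration.

```python
def check_sample_fraction(num_samples, sample_fraction):
    """Toy model of the check; not the ranger implementation."""
    # Assumed paraphrase of the C++ condition:
    # fail when num_samples * sample_fraction[0] is below 1.
    if num_samples * sample_fraction[0] < 1:
        raise RuntimeError("sample_fraction too small, no observations sampled.")


def passes(num_samples, sample_fraction):
    """Return True if the modelled check accepts the inputs."""
    try:
        check_sample_fraction(num_samples, sample_fraction)
        return True
    except RuntimeError:
        return False


# num_samples is the number of rows handed to the entrypoint,
# which on predict is the number of rows being predicted.
results = {
    "one row, fraction 0.5": passes(1, [0.5]),
    "two rows, fraction 0.5": passes(2, [0.5]),
    "one row, fraction 1": passes(1, [1]),
}
print(results)
```

Under this model a single row with a fraction of 0.5 is rejected, while two rows at the same fraction, or a single row with a fraction of 1, are accepted. That is consistent with the behaviour reported at the top of the thread, although a smaller fraction could still fail with more than one row.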

dalekube commented 3 years ago

Problem solved with the 0.3.2 release. Confirmed in my application that predict() ignores the sample_fraction parameter according to the new design. Thanks!
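
For readers who want a picture of what "predict() ignores sample_fraction" can look like, the following is a minimal sketch of the idea discussed above: neutralise the fraction whenever the shared entrypoint is called in prediction mode. It is not the code from the 0.3.2 release. The helper build_ranger_args and its argument names are hypothetical, chosen only for this example.

```python
def build_ranger_args(n_rows, sample_fraction, prediction_mode):
    """Hypothetical helper: assemble arguments for a shared fit/predict entrypoint."""
    # Sampling only matters while growing trees, so for prediction
    # the user-supplied fraction is replaced by a neutral value of 1.
    effective_fraction = [1.0] if prediction_mode else list(sample_fraction)
    return {
        "num_samples": n_rows,
        "sample_fraction": effective_fraction,
        "prediction_mode": prediction_mode,
    }


# Fitting keeps the fraction the user asked for.
fit_args = build_ranger_args(100, [0.5], prediction_mode=False)

# Predicting a single row no longer trips a "fraction too small" check,
# because 1 row * 1.0 is not below 1.
predict_args = build_ranger_args(1, [0.5], prediction_mode=True)
print(fit_args["sample_fraction"], predict_args["sample_fraction"])
```

The design choice here is to leave the user's fitted parameter untouched on the estimator and override it only at the call boundary, so that inspecting the model after prediction still shows the value used for training.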