sebp / scikit-survival

Survival analysis built on top of scikit-learn
GNU General Public License v3.0

RandomSurvivalForest unusually slow #343

Closed shil3754 closed 1 year ago

shil3754 commented 1 year ago

I attempted to fit a RandomSurvivalForest with 500,000 training instances and 140 features on a machine with 90GB memory. Unfortunately, after hours of waiting, the program ran out of memory and crashed. I wasn't able to see any progress even though the parameter 'verbose' was set to 1.

However, I was able to fit a sklearn.RandomForestRegressor using the same data, with the time of the event as the label (regardless of censoring), under otherwise identical settings. The whole fitting process took less than 1 minute. All common parameters, such as 'n_jobs', were the same; the only difference was the type of model. In both cases, 'n_jobs' was set to -1 to make use of parallelization.

I am struggling to understand why there is such a large difference in training time between these two models. Although I expect survival analysis to take somewhat longer than ordinary regression, the gap here is enormous. Unfortunately, my full training dataset has more than 10 million instances, so applying RandomSurvivalForest seems rather hopeless at the moment.

I am wondering if there are any suggestions on how I could speed up the training process.
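
For context, a minimal sketch of the kind of side-by-side timing described above; it uses synthetic data and a much smaller sample size (the shapes and parameters below are illustrative assumptions, not the original setup):

```python
import time

import numpy as np
from sklearn.ensemble import RandomForestRegressor

from sksurv.ensemble import RandomSurvivalForest
from sksurv.util import Surv

rng = np.random.default_rng(0)
n_samples, n_features = 5_000, 140  # scaled down from 500,000 x 140
X = rng.standard_normal((n_samples, n_features))
time_to_event = rng.exponential(scale=10.0, size=n_samples)
event = rng.random(n_samples) < 0.7  # roughly 70% observed events
y = Surv.from_arrays(event=event, time=time_to_event)

# regression forest on the event time, ignoring censoring
start = time.perf_counter()
RandomForestRegressor(n_estimators=100, n_jobs=-1).fit(X, time_to_event)
print(f"RandomForestRegressor: {time.perf_counter() - start:.1f} s")

# survival forest on the structured (event, time) target
start = time.perf_counter()
RandomSurvivalForest(n_estimators=100, n_jobs=-1).fit(X, y)
print(f"RandomSurvivalForest: {time.perf_counter() - start:.1f} s")
```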

sebp commented 1 year ago

That's not good.

sksurv calls sklearn to grow the forest; it only overrides the split criterion. It would be interesting to check whether increasing the number of samples or the number of features has a higher impact.

I would suspect that the number of samples could be the problem, because computing the split criterion involves sorting by the time of an event.
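
A rough sketch of that check, timing a single SurvivalTree on synthetic data while scaling the number of samples and the number of features separately (all sizes are illustrative assumptions):

```python
import time

import numpy as np

from sksurv.tree import SurvivalTree
from sksurv.util import Surv


def time_fit(n_samples, n_features, seed=0):
    """Time fitting a single survival tree on random data."""
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n_samples, n_features))
    y = Surv.from_arrays(
        event=rng.random(n_samples) < 0.7,
        time=rng.exponential(scale=10.0, size=n_samples),
    )
    start = time.perf_counter()
    SurvivalTree().fit(X, y)
    return time.perf_counter() - start


# vary the number of samples at a fixed number of features
for n in (1_000, 2_000, 4_000, 8_000):
    print(f"n_samples={n:>5}, n_features=25: {time_fit(n, 25):.2f} s")

# vary the number of features at a fixed number of samples
for p in (10, 20, 40, 80):
    print(f"n_samples=4000, n_features={p:>3}: {time_fit(4_000, p):.2f} s")
```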

sebp commented 1 year ago

I can confirm that it is indeed the sorting operation that is responsible for ~60% of the time it takes to fit a tree.

[profiling screenshot attached]

solidate commented 1 year ago

I am also having a very similar experience with RandomSurvivalForest, as mentioned by @shil3754. I tested it with 100 records and it worked fine. Then I ran it with 500k records on a 96-core machine with 200GB of memory. It is now taking forever to get any output from RandomSurvivalForest, even after setting min_samples_split, max_depth, and various other hyperparameters.

juancq commented 1 year ago

@sebp what is your recommendation in view of this?

sebp commented 1 year ago

> @sebp what is your recommendation in view of this?

I haven't investigated what's happening inside the tree builder from sklearn, but my hope is that it would be possible to sort only once and then use that order inside sksurv's split criterion, instead of re-sorting on every call from the tree builder.

tommyfuu commented 1 year ago

Having a similar experience with only 33,599 samples and 29 features in the training set, running on a cluster.

sebp commented 1 year ago

FYI, I started looking into this. It should be possible to pre-compute the log-rank statistics for all possible splits in LogrankCriterion.__init__ such that LogrankCriterion.update is just a lookup, without having to sort again.
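
A naive, pure-Python illustration of that idea (the real LogrankCriterion is Cython code; this sketch simply reuses sksurv.compare.compare_survival for each statistic and is not how the criterion is actually implemented):

```python
import numpy as np

from sksurv.compare import compare_survival


class PrecomputedLogrankSplits:
    """Sort once, precompute the log-rank statistic for every split position."""

    def __init__(self, feature_values, y):
        self.order = np.argsort(feature_values)  # sort once, up front
        n = len(self.order)
        self._stats = np.zeros(n - 1)
        for pos in range(1, n):  # evaluate every possible split position
            groups = np.zeros(n, dtype=int)
            groups[self.order[pos:]] = 1  # samples going to the right child
            chisq, _ = compare_survival(y, groups)  # log-rank test statistic
            self._stats[pos - 1] = chisq

    def update(self, pos):
        # a plain lookup instead of sorting and recomputing on every call
        return self._stats[pos - 1]
```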

sebp commented 1 year ago

The fix-rsf-performance branch contains improved code to grow trees. For me, it reduces the time to grow a single tree on 500 samples and 25 features from 1330 ms to 340 ms.

It would be great if you could give it a try too and let me know what you think.
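
For anyone who wants to try the branch, a small benchmark in the spirit of the numbers above could look like this (synthetic data; 500 samples, 25 features):

```python
import time

import numpy as np

from sksurv.tree import SurvivalTree
from sksurv.util import Surv

rng = np.random.default_rng(0)
X = rng.standard_normal((500, 25))
y = Surv.from_arrays(
    event=rng.random(500) < 0.7,
    time=rng.exponential(scale=10.0, size=500),
)

runs = []
for _ in range(10):  # repeat to get a stable estimate
    start = time.perf_counter()
    SurvivalTree().fit(X, y)
    runs.append(time.perf_counter() - start)

print(f"median time to grow a single tree: {np.median(runs) * 1000:.0f} ms")
```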

juancq commented 1 year ago

@sebp there's something wrong with the new implementation. Here is the code I used:

```python
import pandas as pd
from sksurv.ensemble import RandomSurvivalForest
from sksurv import datasets as sksurv_datasets

X, y = sksurv_datasets.load_flchain()
Xt = X.drop('chapter', axis=1)
Xt = pd.get_dummies(Xt)

mask = Xt.isna().any(axis=1)
Xt = Xt[~mask]
y = y[~mask]

model = RandomSurvivalForest(n_estimators=100)
model.fit(Xt, y)
```

With 0.20.0, I get the following running times:

| n_estimators | run time |
| --- | --- |
| 10 | 6 seconds |
| 100 | 16 seconds |
| 200 | 16 seconds |
| 500 | 17 seconds |
| 1000 | 17 seconds |

With 0.21.0, I get the following running times:

| n_estimators | run time |
| --- | --- |
| 10 | 5 seconds |
| 100 | 34 seconds |
| 200 | 68 seconds |
| 500 | 97 seconds and then "killed" |

```
System:
    python: 3.10.8 (main, Dec 5 2022, 10:38:26) [GCC 12.2.0]
executable: /home/helloworld/venvs/survival_experiments/bin/python
   machine: Linux-4.18.0-425.19.2.el8_7.x86_64-x86_64-with-glibc2.28

Python dependencies:
      sklearn: 1.2.2
          pip: 23.1.2
   setuptools: 63.2.0
        numpy: 1.23.5
        scipy: 1.10.1
       Cython: 0.29.35
       pandas: 1.5.3
   matplotlib: 3.7.1
       joblib: 1.2.0
threadpoolctl: 3.1.0

Built with OpenMP: True

threadpoolctl info:
       user_api: openmp
   internal_api: openmp
         prefix: libgomp
       filepath: /home/helloworld/venvs/survival_experiments/lib/python3.10/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
        version: None
    num_threads: 1

       user_api: blas
   internal_api: openblas
         prefix: libopenblas
       filepath: /home/helloworld/venvs/survival_experiments/lib/python3.10/site-packages/numpy.libs/libopenblas64_p-r0-742d56dc.3.20.so
        version: 0.3.20
threading_layer: pthreads
   architecture: SkylakeX
    num_threads: 1

       user_api: blas
   internal_api: openblas
         prefix: libopenblas
       filepath: /home/helloworld/venvs/survival_experiments/lib/python3.10/site-packages/scipy.libs/libopenblasp-r0-41284840.3.18.so
        version: 0.3.18
threading_layer: pthreads
   architecture: SkylakeX
    num_threads: 1

Traceback (most recent call last):
  File "/home/helloworld/survival-experiments/version.py", line 3, in <module>
    import cvxopt; print("cvxopt:", cvxopt.__version__)
ModuleNotFoundError: No module named 'cvxopt'

Traceback (most recent call last):
  File "/home/helloworld/survival-experiments/version.py", line 4, in <module>
    import cvxpy; print("cvxpy:", cvxpy.__version__)
ModuleNotFoundError: No module named 'cvxpy'

sksurv: 0.21.0
numexpr: 2.8.4
osqp: 0.6.2
```

sebp commented 1 year ago

@juancq Could you please post which versions you are using, as described here.

juancq commented 1 year ago

@sebp I updated the comment with the versions I am using, but please wait until I post again: I'm trying to reproduce the running times and am not getting the same numbers now.

juancq commented 1 year ago

@sebp my apologies, the times I posted before for 0.20.0 were the times at which the script received the "killed" signal, not the successful running times, which explains the constant times from 100 to 1000 estimators.

I do see speed ups with version 0.21.0, as evident below:

With 0.20.0:

| n_estimators | run time |
| --- | --- |
| 10 | 5 seconds |
| 100 | 41 seconds |
| 200 | 86 seconds |
| 500 | 217 seconds and then "killed" |
| 1000 | 217 seconds and then "killed" |

With 0.21.0:

| n_estimators | run time |
| --- | --- |
| 10 | 5 seconds |
| 100 | 34 seconds |
| 200 | 68 seconds |
| 500 | 97 seconds and then "killed" |

oelhammouchi commented 1 year ago

Hi, thanks a lot for your work on this package! I've been experiencing similar performance difficulties despite the fix in 0.21. My training data consists of approx. 50K rows and 27 features (after one-hot encoding). Fitting takes 15-20 min for a single run, so it's very difficult to do cross validation, feature selection, etc. Any idea how I could remedy this? Below is the output of cProfile.

[cProfile output screenshot attached]
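
For reference, a minimal sketch of how such a cProfile run could be set up (synthetic data and a scaled-down shape are assumed here, not the actual training set):

```python
import cProfile
import pstats

import numpy as np

from sksurv.ensemble import RandomSurvivalForest
from sksurv.util import Surv

rng = np.random.default_rng(0)
X = rng.standard_normal((5_000, 27))  # scaled down from ~50K rows x 27 features
y = Surv.from_arrays(
    event=rng.random(5_000) < 0.7,
    time=rng.exponential(scale=10.0, size=5_000),
)

# few trees and a single process keep the profile short and readable
with cProfile.Profile() as profiler:
    RandomSurvivalForest(n_estimators=10, n_jobs=1).fit(X, y)

pstats.Stats(profiler).sort_stats("cumulative").print_stats(15)
```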

sebp commented 1 year ago

@OthmanElHammouchi The current bottleneck is #382

Under the hood, scikit-learn is used to build trees. Unfortunately, this means that each node in every tree contains a survival and a cumulative hazard function, which causes overhead due to large portions of memory being copied. I don't have a straightforward solution for this.
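
A back-of-envelope sketch of that overhead (all counts below are illustrative assumptions, not measurements):

```python
# All counts are illustrative assumptions, not measurements.
n_trees = 100
n_leaves_per_tree = 10_000       # deep trees grown on ~50k samples
n_nodes_per_tree = 2 * n_leaves_per_tree - 1
n_unique_times = 30_000          # roughly one unique time per uncensored sample
bytes_per_value = 8              # float64

# each node stores a survival and a cumulative hazard function,
# both evaluated at every unique event time
per_node = 2 * n_unique_times * bytes_per_value
per_tree = n_nodes_per_tree * per_node
total = n_trees * per_tree

print(f"per tree:  ~{per_tree / 1e9:.1f} GB")
print(f"100 trees: ~{total / 1e9:.0f} GB")
```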

oelhammouchi commented 1 year ago

@sebp Ah, I see, thanks for your reply. I'm not familiar with the sklearn implementation; I tried looking into it yesterday, but it does seem quite intricate.

gpwhs commented 2 months ago

Has there been any movement on this issue? I'm still finding RSF to be incredibly slow on a dataset with ~40 features, 200k samples.

sebp commented 2 months ago

Not yet, I'm afraid.

This will require a considerable amount of work, because it means no longer relying on scikit-learn to build the tree and instead implementing the growing of survival trees from scratch.

gpwhs commented 2 months ago

@sebp have you done any scoping at all? I may have some capacity to attack this :)

sebp commented 2 months ago

See #382 (please continue discussion there)

Essentially, by relying on scikit-learn's implementation, we have to store the survival and cumulative hazard function in each node of the tree. With many samples you also have many time points, and thus a large memory overhead that has to be initialized.