siboehm / lleaves

Compiler for LightGBM gradient-boosted trees, based on LLVM. Speeds up prediction by ≥10x.
https://lleaves.readthedocs.io/en/latest/
MIT License
354 stars, 29 forks

`Model.compile` is slow for a random forest with a large number of trees #60

Open: trendelkampschroer opened this issue 10 months ago

trendelkampschroer commented 10 months ago

I am copying over parts of my comments from the discussion started in #58.

Consider the following example, in which a random forest with a large number of trees is compiled via `lleaves.Model.compile`:

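import cProfile
import pstats
import tempfile
from pathlib import Path

import lightgbm
import lleaves
from lightgbm import LGBMRegressor
from sklearn.datasets import make_regression

if __name__ == "__main__":
    n_samples = 10_000
    X, y = make_regression(n_samples=n_samples, n_features=5, noise=10.0)

    # Random forest mode ("rf") with a large number of trees
    num_trees = 1_000
    params = {
        "objective": "regression",
        "n_jobs": 1,
        "boosting_type": "rf",
        "subsample_freq": 1,  # rf mode needs bagging (or feature subsampling) enabled
        "subsample": 0.9,
        "colsample_bytree": 0.9,
        "num_leaves": 25,
        "n_estimators": num_trees,
        "min_child_samples": 100,
        "verbose": 0,
    }

    model = LGBMRegressor(**params).fit(X, y)

    # The original snippet used an undefined `tmpdir`; a fresh temporary directory stands in for it
    tmpdir = Path(tempfile.mkdtemp())
    model_file = str(tmpdir / "model.txt")
    model.booster_.save_model(model_file)

    lgbm = lightgbm.Booster(model_file=model_file)  # reference LightGBM model (not used below)
    llvm = lleaves.Model(model_file=model_file)

    # Profile only the compilation step
    with cProfile.Profile() as pr:
        llvm.compile()

    stats = pstats.Stats(pr)
    stats.sort_stats("tottime")  # "Ordered by: internal time", matching the output below
    stats.print_stats(20)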
Profiler output:

5022588 function calls (4606478 primitive calls) in 9.326 seconds

   Ordered by: internal time
   List reduced from 290 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       37    7.454    0.201    7.454    0.201 miniforge3/envs/lleaves/lib/python3.11/site-packages/llvmlite/binding/ffi.py:188(__call__)
118144/65575    0.156    0.000    0.654    0.000 miniforge3/envs/lleaves/lib/python3.11/site-packages/llvmlite/ir/_utils.py:44(__str__)
489894/176453    0.146    0.000    0.700    0.000 {method 'format' of 'str' objects}
   100777    0.130    0.000    0.222    0.000 miniforge3/envs/lleaves/lib/python3.11/site-packages/llvmlite/ir/_utils.py:16(register)
   100777    0.097    0.000    0.358    0.000 miniforge3/envs/lleaves/lib/python3.11/site-packages/llvmlite/ir/values.py:537(__init__)
   187530    0.072    0.000    0.230    0.000 miniforge3/envs/lleaves/lib/python3.11/site-packages/llvmlite/ir/_utils.py:54(get_reference)
    28191    0.067    0.000    0.258    0.000 miniforge3/envs/lleaves/lib/python3.11/site-packages/llvmlite/ir/values.py:1154(__init__)
24000/1000    0.060    0.000    0.854    0.001 .../git/lleaves/lleaves/compiler/codegen/codegen.py:128(_gen_decision_node)
    65570    0.059    0.000    0.609    0.000

One can see that the vast majority of the time (7.45 s of 9.33 s, roughly 80%) is spent in llvmlite's FFI calls (`ffi.py:188(__call__)`), i.e. the overhead of parsing the trees and building the IR is small compared to the cost of the actual compilation. It would be really helpful if the compile time could be reduced further: the compiled artifact is machine-specific, so a safe deployment on, e.g., a Kubernetes cluster may require a costly recompilation every time the model is deployed.
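Since the artifact is machine-specific, one mitigation is to compile once per (model, machine type) pair and reuse the result across deployments on identical nodes. The sketch below only derives a cache file path from a hash of the model file plus the platform; `compiled_cache_path` and the file naming scheme are hypothetical. The commented usage assumes that your lleaves version's `Model.compile` accepts a `cache` path (check the API reference for your release) and that `llvmlite.binding.get_host_cpu_name()` is a good enough way to tell CPU generations apart.

```python
import hashlib
import platform
from pathlib import Path


def compiled_cache_path(model_file: str, cache_dir: str, cpu_name: str = "") -> Path:
    """Hypothetical helper: a cache file path that changes whenever the model
    file contents or the target machine (OS, architecture, CPU name) change."""
    model_digest = hashlib.sha256(Path(model_file).read_bytes()).hexdigest()
    # NUL-separated so that e.g. ("ab", "c") and ("a", "bc") produce different keys
    key = "\0".join([model_digest, platform.system(), platform.machine(), cpu_name])
    digest = hashlib.sha256(key.encode()).hexdigest()[:16]
    return Path(cache_dir) / f"lleaves-{digest}.bin"


# Usage sketch (requires lleaves and llvmlite; not executed here):
#   import lleaves
#   from llvmlite import binding
#   path = compiled_cache_path("model.txt", "/var/cache/models", binding.get_host_cpu_name())
#   llvm = lleaves.Model(model_file="model.txt")
#   llvm.compile(cache=str(path))  # assumed: loads the file if present, else compiles and writes it
```

Keying on the CPU name is deliberately conservative: nodes with a different CPU get their own entry instead of loading code that may use unsupported instructions.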

In addition, I observe the following compile times for different values of num_trees:

num_trees   compile time (s)
10          0.06
100         0.8
1000        9.3
10000       167.9
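A rough way to read this table is to fit time ≈ c · num_trees^k on a log-log scale. The sketch below uses only the four measurements above; with so few points the fitted exponent is indicative at best.

```python
import math

# Compile times reported above: num_trees -> seconds
timings = {10: 0.06, 100: 0.8, 1_000: 9.3, 10_000: 167.9}


def loglog_slope(points):
    """Least-squares slope of log(time) vs log(num_trees), i.e. the exponent k
    in time ~ c * num_trees**k."""
    xs = [math.log(n) for n, _ in points]
    ys = [math.log(t) for _, t in points]
    mx, my = sum(xs) / len(xs), sum(ys) / len(ys)
    num = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    den = sum((x - mx) ** 2 for x in xs)
    return num / den


pts = sorted(timings.items())
# Slowdown factor for each 10x step in the number of trees
for (n0, t0), (n1, t1) in zip(pts, pts[1:]):
    print(f"{n0:>6} -> {n1:>6} trees: {t1 / t0:5.1f}x slower")
print(f"fitted exponent k ~ {loglog_slope(pts):.2f}")
```

On these numbers, each 10x increase in trees makes compilation about 13.3x, 11.6x and 18.1x slower, and the fitted exponent is about 1.14: somewhat worse than linear, with the largest jump between 1,000 and 10,000 trees.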

Notably, the compilation step does not seem to benefit from multiple cores on my machine, so if there is a way to simply throw more CPU resources at it, I'd be happy to explore it.
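One idea for using more cores, sketched here under assumptions and not something lleaves currently offers: a tree ensemble's raw score is a combination of per-tree outputs, so the trees could in principle be split into contiguous shards, each compiled (for example in its own process) and evaluated independently, with the shard results added up afterwards. For `boosting_type="rf"`, LightGBM averages rather than sums the tree outputs, so the combined sum is divided by the total tree count. `shard_ranges`, `shard_sum` and `combine_shards` are hypothetical helpers, and plain Python callables stand in for compiled trees.

```python
from typing import Callable, List, Sequence, Tuple

Tree = Callable[[Sequence[float]], float]  # stand-in for one compiled tree


def shard_ranges(num_trees: int, num_shards: int) -> List[Tuple[int, int]]:
    """Split tree indices [0, num_trees) into contiguous, near-equal
    [start, stop) ranges; assumes num_shards >= 1."""
    base, extra = divmod(num_trees, num_shards)
    ranges, start = [], 0
    for i in range(num_shards):
        stop = start + base + (1 if i < extra else 0)
        ranges.append((start, stop))
        start = stop
    return ranges


def shard_sum(trees: Sequence[Tree], rng: Tuple[int, int], x: Sequence[float]) -> float:
    """Partial raw score of one shard: the sum of its trees' outputs."""
    start, stop = rng
    return sum(tree(x) for tree in trees[start:stop])


def combine_shards(partials: Sequence[float], num_trees: int, average: bool) -> float:
    """Add the shard partial sums; divide by the tree count for averaged (rf-style) ensembles."""
    total = sum(partials)
    return total / num_trees if average else total


if __name__ == "__main__":
    # Toy "trees": a constant plus a multiple of the first feature
    trees = [lambda x, k=k: k + 0.1 * x[0] for k in range(10)]
    x = [2.0]
    parts = [shard_sum(trees, r, x) for r in shard_ranges(len(trees), 3)]
    print(combine_shards(parts, len(trees), average=True))
```

The shards are independent, so compiling them is embarrassingly parallel; the price is one extra addition per shard at prediction time.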

siboehm commented 10 months ago

See also #48. Some comments:

CaypoH commented 4 months ago

I waited about 10 minutes for the compile step and then interrupted the process. Any speedup is not worth it if my application takes more than 10 minutes to start. Tested on a MacBook Pro M1.

Soontosh commented 1 month ago

Similar situation here. I'm running lleaves on a TPU with 334 GB of RAM, but compilation still takes forever: even 1 hour is not enough (I haven't managed to complete a compilation yet). For context, I am using a LightGBM classifier with the following parameters:

I have tried setting `finline` to both True and False.

I don't know what is happening at a low level, but RAM usage grows steadily to ~66 GB, then plateaus, and nothing further happens.
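For reports like this one, a log of peak memory over time can help show where compilation stalls. The sketch below is a hypothetical, Unix-only helper (it relies on the standard `resource` module) that prints the process's peak resident set size from a daemon thread at a fixed interval. It assumes the compile call releases the GIL while LLVM runs, which ctypes-based foreign calls such as llvmlite's normally do; if it does not, the log simply pauses until `compile()` returns.

```python
import resource
import sys
import threading
import time


def peak_rss_mib() -> float:
    """Peak resident set size of this process so far, in MiB (Unix only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in bytes on macOS but in kilobytes on Linux
    return peak / (1024 * 1024) if sys.platform == "darwin" else peak / 1024


def start_memory_logger(interval_s: float = 30.0) -> threading.Event:
    """Print elapsed time and peak RSS every `interval_s` seconds from a daemon
    thread; set the returned event to stop logging."""
    stop = threading.Event()
    t0 = time.monotonic()

    def _log() -> None:
        # Event.wait returns False on timeout, True once the event is set
        while not stop.wait(interval_s):
            print(f"[{time.monotonic() - t0:7.1f} s] peak RSS: {peak_rss_mib():,.0f} MiB", flush=True)

    threading.Thread(target=_log, daemon=True).start()
    return stop


if __name__ == "__main__":
    stop = start_memory_logger(interval_s=0.1)
    time.sleep(0.35)  # stand-in for a long-running call such as llvm.compile()
    stop.set()
```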