siboehm / lleaves

Compiler for LightGBM gradient-boosted trees, based on LLVM. Speeds up prediction by ≥10x.
https://lleaves.readthedocs.io/en/latest/
MIT License

`Model.compile` fails for `LGBMRegressor` with `boosting_type='rf'` #58

Closed trendelkampschroer closed 8 months ago

trendelkampschroer commented 8 months ago
...
 File "/Users/btschroer/.miniforge3/envs/selection/lib/python3.10/site-packages/lleaves/lleaves.py", line 118, in compile
    module = compiler.compile_to_module(
  File "/Users/btschroer/.miniforge3/envs/selection/lib/python3.10/site-packages/lleaves/compiler/tree_compiler.py", line 17, in compile_to_module
    forest = parse_to_ast(file_path)
  File "/Users/btschroer/.miniforge3/envs/selection/lib/python3.10/site-packages/lleaves/compiler/ast/parser.py", line 96, in parse_to_ast
    scanned_model = scan_model_file(model_path)
  File "/Users/btschroer/.miniforge3/envs/selection/lib/python3.10/site-packages/lleaves/compiler/ast/scanner.py", line 35, in scan_model_file
    res["general_info"] = _scan_block(general_info_block, INPUT_SCAN_KEYS)
  File "/Users/btschroer/.miniforge3/envs/selection/lib/python3.10/site-packages/lleaves/compiler/ast/scanner.py", line 109, in _scan_block
    scanned_key, scanned_value = line.split("=")
ValueError: not enough values to unpack (expected 2, got 1)

The booster.txt file for a random forest regressor contains the following "offending" line in its header:

 tree
version=v4
num_class=1
num_tree_per_iteration=1
label_index=0
max_feature_idx=4
objective=regression
average_output  # -> line.split("=")  fails here
feature_names=Column_0 Column_1 Column_2 Column_3 Column_4
...

Deleting the line from the booster.txt file "solves" the problem, but I am not sure that would be correct, since average_output indicates that the tree outputs should be averaged (and not summed, as in the gradient-boosting case).
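For illustration, here is a rough sketch (not the lleaves code, just how I imagine a more tolerant header scanner could look) of handling key-only flag lines instead of failing on `line.split("=")`:

# Hypothetical sketch, not the lleaves implementation: treat header lines
# without "=" (like "average_output") as boolean flags instead of crashing.
def scan_header_block(lines):
    scanned = {}
    for line in lines:
        line = line.strip()
        if not line:
            continue
        if "=" in line:
            key, value = line.split("=", 1)
            scanned[key] = value
        else:
            # key-only marker; for boosting_type='rf' this is "average_output",
            # meaning the per-tree outputs should be averaged, not summed
            scanned[line] = True
    return scanned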

Is it possible to use lleaves for random forest models? I'd be glad to receive any suggestions or help on how to handle this.

Thanks a lot for your time and for creating lleaves and making it available.

siboehm commented 8 months ago

Huh interesting, thanks for the issue. I must've missed this feature. It shouldn't be too hard to implement, I can do it over the weekend. There are 3 steps:

  1. Train a new LGBM model that uses this average output feature, and store the model.txt. Use that model.txt to write a test that compares LightGBM output to lleaves output
  2. Adjust the parser so it doesn't trip over this line but parses it
  3. Adjust the backend to have it compute the average.

I can do 2.) and 3.), but could you help me with 1.)? If you can train a very simple tree that uses this feature and send me the model.txt, that'd help a lot. Basically I'm planning to add it to this integration test: https://github.com/siboehm/lleaves/blob/master/tests/test_tree_output.py
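Something along these lines should do (sketch, parameter values arbitrary; the rf boosting type plus bagging is what makes LightGBM write the average_output line):

# Sketch of step 1: train a tiny random-forest LGBMRegressor and save its
# model.txt (parameter values are arbitrary, just enough to trigger
# boosting_type='rf' and the average_output header line).
from lightgbm import LGBMRegressor
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=1_000, n_features=5, noise=10.0)
model = LGBMRegressor(
    objective="regression",
    boosting_type="rf",
    subsample=0.9,
    subsample_freq=1,
    colsample_bytree=0.9,
    n_estimators=5,
    num_leaves=8,
).fit(X, y)
model.booster_.save_model("rf_model.txt")  # header should contain "average_output"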

siboehm commented 8 months ago

Ok see #59. Can you test that branch and see if it fixes your issue? In particular, please make sure the output of lleaves and your original LGBM model is (approximately) equal.

trendelkampschroer commented 8 months ago

@siboehm Thanks a lot for the quick response and the PR with the fix.

On my side the following test still fails (adapted from your classifier test):

import lightgbm
import lleaves
import numpy as np
import pytest
from lightgbm import LGBMRegressor
from sklearn.datasets import make_regression


@pytest.mark.parametrize(
    "num_trees",
    [
        pytest.param(34, id="this is still ok"),
        pytest.param(35, id="this fails", marks=pytest.mark.xfail(strict=True)),
    ],
)
def test_rf_regression(tmpdir, num_trees):
    n_samples = 10_000
    X, y = make_regression(n_samples=n_samples, n_features=5, noise=10.0)

    params = {
        "objective": "regression",
        "n_jobs": 1,
        "boosting_type": "rf",
        "subsample_freq": 1,
        "subsample": 0.9,
        "colsample_bytree": 0.9,
        "num_leaves": 25,
        "n_estimators": num_trees,
        "min_child_samples": 100,
        "verbose": 0,
    }

    model = LGBMRegressor(**params).fit(X, y)
    model_file = str(tmpdir / "model.txt")
    model.booster_.save_model(model_file)

    lgbm = lightgbm.Booster(model_file=model_file)
    llvm = lleaves.Model(model_file=model_file)
    llvm.compile()

    print(lgbm.predict(X, n_jobs=2) / llvm.predict(X, n_jobs=2))
    np.testing.assert_almost_equal(
        lgbm.predict(X, n_jobs=2), llvm.predict(X, n_jobs=2), decimal=10
    )
    assert lgbm.num_model_per_iteration() == llvm.num_model_per_iteration()

So there seems to be something going on that prevents a random forest with more than 34 trees from being compiled correctly.

And now looking into the code, 34 is exactly the default value of the `fblocksize` variable used for "chunking" the trees during execution (as far as I understand this, with very limited knowledge of the IR-generating code). So it seems that `num_trees` in the `fdiv` call is not the total number of trees but rather the size of the chunks of trees that get executed together.

Thanks a lot for trying to fix this. My understanding of the internals (and of the performance implications of changes) is far too rough to attempt a fix on my own. It'd be super helpful if you could chime in; it's probably just a matter of applying the division to the final reduction instead of to each chunk of trees.
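A toy illustration of what I suspect is happening (plain Python, not the lleaves code, assuming a block size of 34):

# Toy numbers: with 35 trees split into blocks of 34, dividing each block's
# partial sum by 34 differs from dividing the final sum by the total count.
tree_outputs = [1.0] * 35
blocks = [tree_outputs[i:i + 34] for i in range(0, len(tree_outputs), 34)]

per_block_division = sum(sum(block) / 34 for block in blocks)              # ~1.029 (wrong)
final_reduction = sum(sum(block) for block in blocks) / len(tree_outputs)  # 1.0 (matches LightGBM)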

You can also use the test to profile the compile call when there are "many" trees, in which case it can take a long time.


import cProfile
import pstats

with cProfile.Profile() as pr:
    llvm.compile()
    stats = pstats.Stats(pr)
stats.sort_stats("tottime")
stats.print_stats(20)

This is for 1000 trees:

5022588 function calls (4606478 primitive calls) in 9.326 seconds

   Ordered by: internal time
   List reduced from 290 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       37    7.454    0.201    7.454    0.201 miniforge3/envs/lleaves/lib/python3.11/site-packages/llvmlite/binding/ffi.py:188(__call__)
118144/65575    0.156    0.000    0.654    0.000 miniforge3/envs/lleaves/lib/python3.11/site-packages/llvmlite/ir/_utils.py:44(__str__)
489894/176453    0.146    0.000    0.700    0.000 {method 'format' of 'str' objects}
   100777    0.130    0.000    0.222    0.000 miniforge3/envs/lleaves/lib/python3.11/site-packages/llvmlite/ir/_utils.py:16(register)
   100777    0.097    0.000    0.358    0.000 miniforge3/envs/lleaves/lib/python3.11/site-packages/llvmlite/ir/values.py:537(__init__)
   187530    0.072    0.000    0.230    0.000 miniforge3/envs/lleaves/lib/python3.11/site-packages/llvmlite/ir/_utils.py:54(get_reference)
    28191    0.067    0.000    0.258    0.000 miniforge3/envs/lleaves/lib/python3.11/site-packages/llvmlite/ir/values.py:1154(__init__)
24000/1000    0.060    0.000    0.854    0.001 .../git/lleaves/lleaves/compiler/codegen/codegen.py:128(_gen_decision_node)
    65570    0.059    0.000    0.609    0.000

But the vast majority of time is spent in the llvmlite calls, i.e. the overhead for parsing the tree and building the IR seems relatively small compared to the actual cost of compilation.

Is it actually possible to shift most of the compile time into a machine-unspecific intermediate that can be loaded when deploying the model, so that the "final" compilation step is faster and only generates the machine-specific binary? Sorry if this is a somewhat naive question; I have almost no understanding of the code optimisation carried out by (modern) compiler stacks like LLVM.

siboehm commented 8 months ago

Oh yeah good catch, the denominator of the averaging is wrong. That won't be hard to fix. Good debugging!

As to the compile time: yeah, the compile time really sucks, this is the biggest problem with lleaves currently. I have some thoughts as to why and how to make it faster here: https://github.com/siboehm/lleaves/issues/48. The real fix (carefully timing & enabling optimization passes, plus writing my own inliner / generating already-inlined code) is quite a large refactor and I probably won't get around to it anytime soon. I've been toying with rewriting the whole backend using LLVM's C++ interface, but this is really a personal project for me so there are no deadlines ;)

It sounds like (I'm guessing here) you're using lleaves on a cluster of machines / inside cloud virtual machines where you're never really sure what the precise underlying CPU will be. Some ideas that may work for you:

  1. (not hacky) Obviously, use the caching as much as you can. You can look at this code and use it to cache the compiled binary across machines (if the CPU features are the same, the output binary will be the same, so you can reuse the cachefile). With how slow the compilation is, even having the cachefile on S3 would be a huge speed gain.
  2. (a little bit hacky) You could write a script that just loads the cached file and tries to run a prediction on some arbitrary data. If it fails with SIGILL, you know you'll have to recompile. The lleaves code is not actually that architecture-specific (there's no vectorization etc.), so as long as your machines are not extremely different (e.g. you don't have some modern CPUs with AVX but also some old ones without), this will work; see the rough sketch below.
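
Rough sketch of idea 2 (paths are placeholders; assumes the cachefile was produced beforehand via the cache argument of Model.compile): run the prediction in a subprocess so a SIGILL crash shows up as a return code instead of killing the caller.

import signal
import subprocess
import sys

# Probe script: load the cached binary and run one prediction on dummy data.
# "model.txt" and "model.cache" are placeholders for your own files.
probe = """
import numpy as np
import lleaves
model = lleaves.Model(model_file="model.txt")
model.compile(cache="model.cache")   # reuses the cached binary if it exists
model.predict(np.zeros((1, 5)))      # 5 = number of features in the model
"""

result = subprocess.run([sys.executable, "-c", probe])
if result.returncode == -signal.SIGILL:
    print("Cached binary uses unsupported instructions on this CPU -> recompile")
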
siboehm commented 8 months ago

I'll make a release sometime later. Feel free to open a new issue for the compile-time stuff / discuss on the other issue, but I'm closing this one. Thanks for filing the issue and helping out, this was a smooth process! :)

trendelkampschroer commented 8 months ago

Thanks a lot @siboehm for the quick fix and your response. I will summarise my benchmarks and create a new issue tracking the compile time performance. Looking forward to trying out the fix in a new release.