siboehm / lleaves

Compiler for LightGBM gradient-boosted trees, based on LLVM. Speeds up prediction by ≥10x.
https://lleaves.readthedocs.io/en/latest/
MIT License

Support for predict_proba / pred_contrib #5

Closed Zahlii closed 3 years ago

Zahlii commented 3 years ago

Hi, for quite a few of our applications we rely on predict_proba() as well as predict(pred_contrib=True) / SHAP values. Do you have a sense of how complex it would be to add support for these?

siboehm commented 3 years ago

I'll check on the weekend how much effort it would be to implement. I expect this to be a commonly requested feature.

siboehm commented 3 years ago

I looked into it:

predict_proba

This function is not part of the LightGBM API (it is not a method on lightgbm.Booster) but of the scikit-learn API, which I don't plan on supporting with lleaves. It outputs the probabilities for each class in multiclass prediction. Multiclass prediction is not implemented yet either, but I'm currently working on it. Once it is implemented, lleaves.Model.predict() will return the probabilities for each class, which is exactly what predict_proba would do, too. So basically this feature is coming soon :)
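To make "probabilities for each class" concrete, here is a minimal pure-Python sketch. It assumes LightGBM's default `multiclass` (softmax) objective, where each class gets a raw score (the sum of leaf values over that class's trees) and the scores are normalized with softmax. The helper names `softmax`, `predict_proba_from_raw` and `predict_class` are made up for illustration and are not part of lleaves or LightGBM.

```python
import math

def softmax(raw_scores):
    """Turn one row of per-class raw scores into probabilities."""
    m = max(raw_scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in raw_scores]
    total = sum(exps)
    return [e / total for e in exps]

def predict_proba_from_raw(raw_rows):
    """Hypothetical helper: one probability row per sample."""
    return [softmax(row) for row in raw_rows]

def predict_class(proba_rows):
    """Index of the most probable class per sample (argmax)."""
    return [max(range(len(row)), key=row.__getitem__) for row in proba_rows]

# Two samples, three classes; the raw scores are invented for the example.
raw = [[2.0, 0.5, -1.0], [0.0, 0.0, 3.0]]
proba = predict_proba_from_raw(raw)   # shape: (n_samples, n_classes)
labels = predict_class(proba)         # hard class labels, if needed
```

Each row of `proba` sums to 1. Code that needs hard labels rather than probabilities can take the argmax per row, as `predict_class` does.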

pred_contrib=True

I don't think I'll implement this, as it would make the lleaves backend much more complex. I'm currently looking at vectorizing parts of the trees, which should bring large speedups, but vectorizing the feature contribution calculation is just too much of a complication. It also doesn't seem like a common use case: none of the other tree compilers (treelite, ONNX, Hummingbird) support pred_contrib.
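As a rough illustration of the extra bookkeeping that feature contributions need, here is a small sketch on a hand-built tree. It deliberately uses simple path-based (Saabas-style) attribution, not the TreeSHAP algorithm that LightGBM uses for pred_contrib, and it is not lleaves code. The dict-based tree layout, the node values and the function name `path_contributions` are all assumptions made for the example. What it shares with pred_contrib is the output layout: one value per feature plus a final bias entry, which together sum to the raw prediction.

```python
def path_contributions(tree, x, n_features):
    """Attribute a single tree's prediction to features along the decision path.

    Returns n_features + 1 values; the last entry is the bias
    (the expected value stored at the root).
    """
    contrib = [0.0] * (n_features + 1)
    contrib[-1] = tree["value"]
    node = tree
    while "feature" in node:  # internal nodes carry a split, leaves do not
        f = node["feature"]
        child = node["left"] if x[f] <= node["threshold"] else node["right"]
        # Credit the change in expected value to the feature that was split on.
        contrib[f] += child["value"] - node["value"]
        node = child
    return contrib

# Hypothetical tree: every node stores the expected output of its subtree.
tree = {
    "feature": 0, "threshold": 1.0, "value": 0.5,
    "left": {"value": 0.2},
    "right": {
        "feature": 1, "threshold": 3.0, "value": 0.8,
        "left": {"value": 0.6},
        "right": {"value": 1.0},
    },
}

x = [2.0, 5.0]
contrib = path_contributions(tree, x, n_features=2)
prediction = sum(contrib)  # equals the value of the leaf that x reaches
```

Plain prediction only needs the final leaf value. Contributions additionally need the expected value of every node on the path and a per-feature accumulator, and that extra state is what gets in the way of vectorizing.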

siboehm commented 3 years ago

There is now a fully working version of multiclass prediction (including class-probability prediction) on the multiclass_prediction branch. It'll take me a few more days to do some more testing, merge & release, but you can already give it a try.