interpretml / interpret

Fit interpretable models. Explain blackbox machine learning.
https://interpret.ml/docs

Question about Inference Time #494

Open mtl-tony opened 5 months ago

mtl-tony commented 5 months ago

Have there been any papers or studies on inference time? I recall inference speed being cited as one of the advantages, since the tabular lookup format of the model results in fewer FLOPs, which makes sense, but I was wondering whether there are any precise results for these time comparisons. An example would be comparing, let's say, a standard EBM vs GLM vs GBM vs Transformers on the same dataset.

paulbkoch commented 5 months ago

Hi @mtl-tony -- Inference speed is a complicated topic, and the real answer is that it depends. Our paper has comparisons on a few datasets: https://arxiv.org/pdf/1909.09223.pdf

If you're looking for ways to improve inference speed, by far the biggest impact you can make is to make predictions in batches. If you make predictions one sample at a time, then most of the time is spent executing Python code that extracts the raw data from numpy/pandas/scipy, and only a minority is spent on the actual inference. This limitation applies to all the popular tree boosting packages that I've looked at. As a consequence, most models (assuming good implementations) tend to perform in the same ballpark if you're making predictions one at a time.
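
A minimal timing sketch of the batching effect on synthetic data (the dataset and sizes here are illustrative, not from this thread):

```python
import time
import numpy as np
from interpret.glassbox import ExplainableBoostingClassifier

X = np.random.rand(2000, 10)
y = np.random.randint(0, 2, 2000)
ebm = ExplainableBoostingClassifier().fit(X, y)

# One call over all samples: the Python data-extraction overhead is paid once.
start = time.perf_counter()
ebm.predict(X)
print(f"batched:       {time.perf_counter() - start:.3f}s")

# One call per sample: the same overhead is paid 2,000 times.
start = time.perf_counter()
for row in X:
    ebm.predict(row.reshape(1, -1))
print(f"one at a time: {time.perf_counter() - start:.3f}s")
```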

For batched predictions there can be quite a bit of variability between model types. EBMs will tend to be faster in scenarios where there are fewer bins, fewer features, and the features are more important. If there are a lot of unimportant features then other tree algorithms might ignore most of them, and thus not pay much inference cost for those features. This paper describes a method of pruning features from an existing EBM, with one of the goals being improved inference time: https://arxiv.org/pdf/2311.07452.pdf
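
As a hedged sketch of that idea: recent interpret releases expose `term_importances` and `remove_terms` on fitted EBMs, which can be used to drop low-importance terms (check your version; the threshold and data below are illustrative):

```python
import numpy as np
from interpret.glassbox import ExplainableBoostingClassifier

X = np.random.rand(2000, 20)
y = np.random.randint(0, 2, 2000)
ebm = ExplainableBoostingClassifier().fit(X, y)

importances = ebm.term_importances()  # one importance score per term
weak = [i for i, imp in enumerate(importances) if imp < 0.01]  # illustrative cutoff
ebm.remove_terms(weak)  # fewer terms => fewer lookups per prediction
```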

EBMs also tend to do comparatively better when given Fortran-ordered numpy data, or data in pandas DataFrames (which are typically Fortran-ordered).
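
A small sketch of that conversion, assuming a fitted EBM and synthetic data (`np.asfortranarray` copies only if the input isn't already Fortran-ordered):

```python
import numpy as np
from interpret.glassbox import ExplainableBoostingClassifier

X = np.random.rand(2000, 10)  # numpy defaults to C (row-major) order
y = np.random.randint(0, 2, 2000)
ebm = ExplainableBoostingClassifier().fit(X, y)

X_f = np.asfortranarray(X)  # column-major, like most pandas DataFrames
assert X_f.flags["F_CONTIGUOUS"]
ebm.predict(X_f)  # same predictions, cheaper per-feature extraction
```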

If you have categorical data and can use pandas, use a CategoricalDtype. It's around 50 times faster than using object arrays. If you need to stick with numpy, then for the fastest speeds you'd want to convert your categories to floats yourself and pass in a feature type indicating that the feature is categorical.
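
A hedged sketch of both paths on synthetic data (the column names and encoding dict are illustrative):

```python
import numpy as np
import pandas as pd
from interpret.glassbox import ExplainableBoostingClassifier

# Fast path: pandas CategoricalDtype instead of object-dtype strings.
df = pd.DataFrame({
    "color": pd.Series(["red", "blue"] * 50, dtype="category"),
    "size": np.random.rand(100),
})
y = np.random.randint(0, 2, 100)
ExplainableBoostingClassifier().fit(df, y).predict(df)

# numpy path: encode the categories as floats yourself and declare the
# feature nominal via feature_types so the floats are treated as labels.
codes = {"red": 0.0, "blue": 1.0}
X = np.column_stack([df["color"].map(codes).to_numpy(dtype=float),
                     df["size"].to_numpy()])
ebm = ExplainableBoostingClassifier(feature_types=["nominal", "continuous"])
ebm.fit(X, y).predict(X)
```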

Happy to go into more detail if you have a more specific scenario.