microsoft / hummingbird

Hummingbird compiles trained ML models into tensor computation for faster inference.
MIT License

Support for speeding up Tree SHAP? #180

Open zhanjiezhu opened 4 years ago

zhanjiezhu commented 4 years ago

Wondering if you have thought of adding support for local interpretability e.g. Tree SHAP? (https://arxiv.org/pdf/1802.03888.pdf).

I've just tested the TreeExplainer in SHAP on a random forest regressor. On the same samples and the same model, prediction takes 0.12 seconds while explanation takes 448 seconds, so there seems to be a lot of room for speedup.

interesaaat commented 4 years ago

Hey @zhanjiezhu, this is actually a really good idea! I haven't looked at the SHAP internals; do you think it is something we can convert? Do you want to try?

zhanjiezhu commented 4 years ago

Hi @interesaaat, I'll also need to spend some time on the SHAP internals to answer your questions. I'll first finish the support for Isolation Forest (already done, but tests still need to be added), then come back to this. But if you or anyone else wants to work on this, please feel free to do so.

zhanjiezhu commented 4 years ago

Hey @interesaaat, I've recently been looking into the SHAP internals and will try to work on this; could you assign the issue to me? I still need some time to think of a proposal, but one thing I'm sure of is that I will need a new tree traversal strategy.

interesaaat commented 4 years ago

We love new tree traversal strategies ;) keep us posted!

zhanjiezhu commented 4 years ago

@interesaaat, a quick update on progress: I'm still working on this. I've managed to translate Algorithm 1 of https://arxiv.org/pdf/1802.03888.pdf into tensor operations. As I expected, it calculates the SHAP values correctly and offers some acceleration on GPU, but memory explodes very quickly because of the exponential number of calculations (memory profiling shows that the advanced-indexing step over all 2^M feature subsets is exactly where memory blows up). Next I'm going to translate Algorithm 2, which is harder to express as tensor operations but runs in linear time. I'll keep you posted when I have more progress!
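To illustrate the memory blowup mentioned above: a minimal NumPy sketch (hypothetical, not the actual Hummingbird code) of materializing every feature subset as a boolean mask tensor of shape (2^M, M), the kind of object an all-subsets advanced-indexing step needs before any Shapley weighting is even applied.

```python
import numpy as np
from math import factorial

def subset_masks(M: int) -> np.ndarray:
    """Row k of the result is the M-bit binary encoding of k,
    i.e. the indicator vector of one feature subset."""
    codes = np.arange(2 ** M, dtype=np.int64)[:, None]   # shape (2^M, 1)
    bits = (codes >> np.arange(M)) & 1                   # shape (2^M, M) via broadcasting
    return bits.astype(bool)

M = 10
masks = subset_masks(M)          # (1024, 10) here, but 2^30 x 30 bools
                                 # is already ~32 GB at M = 30
sizes = masks.sum(axis=1)
# Classic Shapley weight |S|! (M - |S| - 1)! / M! for subsets S not
# containing the feature being attributed (guard the full set, where
# the weight is undefined):
weights = np.array([factorial(int(s)) * factorial(M - int(s) - 1) / factorial(M)
                    if s < M else 0.0 for s in sizes])
```

The mask tensor grows as 2^M while the per-subset weight is a cheap scalar, which is why the exponential formulation runs out of memory long before it runs out of compute.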

interesaaat commented 4 years ago

Looking forward to seeing the implementation! It would be fantastic if, similarly to GEMM vs. TreeTrav, we could find cases where the exponential algorithm is faster than the linear one.

interesaaat commented 4 years ago

Hey @zhanjiezhu, we haven't heard from you in a bit, do you have any update?

zhanjiezhu commented 4 years ago

Hi @interesaaat, sorry for the late reply, I was busy with my day job. Unfortunately I was not able to find a good balance between speedup and memory. I have managed to translate Algorithm 2 into tensor computations (the calculated SHAP values are correct), but it involves an iterative calculation (the "EXTEND" function in Algorithm 2 of the paper above) for which I think an explicit for loop cannot be avoided. And Algorithm 2 only covers part of the TreeExplainer interface (the case where no background dataset is provided); I have not yet looked into the "interventional" algorithm used when a background dataset is provided. I've currently run out of ideas and haven't had time to work on this in the last 2 months. Hopefully in the next few weeks I'll have more time to work on it again, or at least provide a more detailed update and share the things I've already tried. But if anyone else wants to take this on, please feel free to unassign me.
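The sequential nature of EXTEND can be seen in a rough Python port of the routine from Algorithm 2 of the Tree SHAP paper (mirroring the reference C implementation in the shap package; field names here are illustrative):

```python
def extend(path, zero_fraction, one_fraction, feature_index):
    """EXTEND from Algorithm 2 of arXiv:1802.03888, modifying `path` in place.

    Each path element is a dict: feature index "d", fraction of samples
    flowing when the feature is excluded "z", when it matches "o", and
    the running permutation weight "w".
    """
    depth = len(path)  # depth before this extension
    path.append({"d": feature_index, "z": zero_fraction,
                 "o": one_fraction, "w": 1.0 if depth == 0 else 0.0})
    # Each iteration reads a weight the next iteration overwrites, so the
    # loop carries a data dependence and cannot be collapsed into a single
    # vectorized tensor op without unrolling it.
    for i in range(depth - 1, -1, -1):
        path[i + 1]["w"] += one_fraction * path[i]["w"] * (i + 1) / (depth + 1)
        path[i]["w"] = zero_fraction * path[i]["w"] * (depth - i) / (depth + 1)

p = []
extend(p, 1.0, 1.0, -1)        # root node
extend(p, 0.5, 1.0, 0)         # split on feature 0; half the samples follow
weights = [e["w"] for e in p]  # [0.25, 0.5]
```

Since the loop bound is just the current path depth, it is static per tree depth, which is what makes unrolling (per the maintainer's note below about TorchScript) plausible.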

interesaaat commented 4 years ago

Don't worry, no one else is working on it right now, so it is still all yours. Now that we have TorchScript (and soon TVM), for loops are OK if they are static, because they will be unrolled. Anyway, keep us posted if you have any further updates. If you want us to look at the code, you can open a draft PR. Up to you!
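A small hedged illustration of the unrolling point: `torch.jit.trace` executes Python control flow once with example inputs, so a loop with a fixed bound (here a hypothetical maximum depth of 4) is flattened into straight-line tensor ops with no loop node in the resulting graph.

```python
import torch

DEPTH = 4  # assumed fixed maximum path depth, for illustration only

def repeated_update(w):
    # Static-bound Python loop: tracing runs it eagerly, so the traced
    # graph contains DEPTH copies of the mul/add and no prim::Loop node.
    for _ in range(DEPTH):
        w = w * 0.5 + 1.0
    return w

traced = torch.jit.trace(repeated_update, torch.ones(3))
out = traced(torch.ones(3))
```

Note this only works when the bound really is static; a loop whose length depends on the input would be baked in at the example input's length, which is exactly why `torch.jit.script` keeps data-dependent loops as loops.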

zhanjiezhu commented 4 years ago

(Screenshot: KernelExplainer benchmark, 2020-11-13)

@interesaaat: a quick update on using Hummingbird for SHAP. I just realised that for KernelExplainer we can simply pass in the prediction function of the Hummingbird model object. This already offers quite a speedup (30x in the case above), because KernelExplainer needs to make predictions on many permutations of instances. So for any model that has no model-specific explainer and is supported by Hummingbird, e.g. SVM or KNN, this approach already provides easy acceleration. For model-specific explainers such as TreeExplainer, we still need to look into the inner workings and translate the algorithm into tensor computations.

scnakandala commented 4 years ago

@zhanjiezhu Nvidia RAPIDS recently released a GPUTreeShap variant implementation (https://github.com/rapidsai/gputreeshap, https://arxiv.org/pdf/2010.13972.pdf). Maybe as a starting point you can check whether you can implement that algorithm purely with tensor ops instead of writing CUDA code.

zhanjiezhu commented 4 years ago

@scnakandala: many thanks for the links, very interesting. I tried implementing TreeShap with tensor ops before but failed; I'll look into their paper and implementation.