zhanjiezhu opened this issue 4 years ago

Wondering if you have thought of adding support for local interpretability, e.g. Tree SHAP? (https://arxiv.org/pdf/1802.03888.pdf)

I've just tested the TreeExplainer in SHAP on a random forest regressor. On the same samples and the same model, prediction takes 0.12 seconds while the explanation takes 448 seconds, so there seems to be plenty of room for speed-up.
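A rough way to reproduce this kind of comparison (the model and data sizes below are illustrative, not my exact setup):

```python
import time
import shap
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

# hypothetical setup: any reasonably sized forest shows the same gap
X, y = make_regression(n_samples=2000, n_features=20, random_state=0)
model = RandomForestRegressor(n_estimators=100, random_state=0).fit(X, y)

t0 = time.perf_counter()
model.predict(X)
t1 = time.perf_counter()

explainer = shap.TreeExplainer(model)
explainer.shap_values(X)
t2 = time.perf_counter()

print(f"predict: {t1 - t0:.2f}s, explain: {t2 - t1:.2f}s")
```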
Hey @zhanjiezhu, this is actually a really good idea! I haven't looked at the SHAP internals; do you think it is something we can convert? Do you want to try?
Hi @interesaaat, I'll also need to spend some time on the SHAP internals to answer your questions. I'll first finish the support for Isolation Forest (the implementation is done but I still need to add tests), then come back to this. But if you or anyone else wants to work on this, please feel free to do so.
Hey @interesaaat, I've recently been looking into the SHAP internals and I will try to work on this, could you assign the issue to me? I will need some more time to come up with a proposal, but one thing I'm sure of is that I will need a new tree traversal strategy.
We love new tree traversal strategies ;) keep us posted!
@interesaaat a quick update on progress: I'm still working on this. I've managed to translate Algorithm 1 of https://arxiv.org/pdf/1802.03888.pdf into tensor operations. It met my expectations in that it calculates the SHAP values correctly and offers some acceleration on GPU, but memory explodes very quickly because of the exponential amount of computation (memory profiling shows that the advanced-indexing step over the 2^M subsets is exactly where memory blows up). Now I'm going to translate Algorithm 2, which is harder to express in tensor operations but runs in linear time. Will keep you posted when I have more progress!
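To make the memory issue concrete, here is a stripped-down sketch of the kind of computation involved. It is not the conditional-expectation Algorithm 1 itself, just a brute-force interventional stand-in with a single background row, and `f`, `x`, `background` are placeholders:

```python
import itertools
import math
import torch

def brute_force_shap(f, x, background):
    # f: callable mapping an (N, M) tensor to (N,) predictions
    # x: (M,) instance to explain; background: (M,) reference row
    M = x.shape[0]
    # all 2^M feature coalitions as one boolean mask tensor; this first
    # dimension is exactly what explodes with the number of features
    masks = torch.tensor(list(itertools.product([0, 1], repeat=M)),
                         dtype=torch.bool)
    # absent features are replaced by background values, f evaluated once
    inputs = torch.where(masks, x, background)   # shape (2^M, M)
    vals = f(inputs)                             # shape (2^M,)

    phi = torch.zeros(M)
    for i in range(M):
        s_idx = (~masks[:, i]).nonzero(as_tuple=True)[0]  # coalitions without i
        sizes = masks[s_idx].sum(dim=1)                   # |S| per coalition
        w = torch.tensor([math.factorial(s) * math.factorial(M - s - 1)
                          / math.factorial(M) for s in sizes.tolist()])
        # advanced indexing over the 2^M axis, the step where memory blows up
        pair_idx = s_idx + (1 << (M - 1 - i))             # index of S ∪ {i}
        phi[i] = (w * (vals[pair_idx] - vals[s_idx])).sum()
    return phi

# sanity check on a linear model: phi_i should be a_i * (x_i - background_i)
a = torch.tensor([1.0, -2.0, 0.5])
f = lambda Z: (Z * a).sum(-1)
print(brute_force_shap(f, torch.ones(3), torch.zeros(3)))
```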
Looking forward to seeing the implementation! It would be fantastic if, similarly to GEMM vs TreeTrav, we can find cases where the exponential algorithm is faster than the linear one.
Hey @zhanjiezhu, we haven't heard from you in a bit, do you have any update?
Hi @interesaaat, sorry for the late reply, I was busy with my daily work. Unfortunately I was not able to find a good balance between speed-up and memory. I have managed to translate Algorithm 2 into tensor computations (the calculated SHAP values are correct), but it involves an iterative calculation (the "EXTEND" function in Algorithm 2 of the paper above) for which I don't think an explicit for loop can be avoided. And Algorithm 2 only covers part of the TreeExplainer interface (the case where no background dataset is provided); I have not yet looked into the "interventional" algorithm used when a background dataset is provided. I've currently run out of ideas and didn't have time to work on this in the last two months. Hopefully in the next weeks I'll have more time to work on it again, or at least provide a more detailed update and share what I've already tried. But if anyone else wants to do this, please feel free to unassign me and take it on.
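For reference, this is roughly how I read the EXTEND step in tensor form; a sketch of my understanding of the paper's recurrence, tracking only the weights m.w (pz and po are the "zero" and "one" path fractions from the paper):

```python
import torch

def extend(w, pz, po):
    # one EXTEND step (my reading of Algorithm 2 in arXiv:1802.03888):
    # w[k] holds the weight of path-feature subsets of size k
    n = w.shape[0]
    if n == 0:
        return torch.ones(1, dtype=w.dtype, device=w.device)
    k = torch.arange(n + 1, dtype=w.dtype, device=w.device)
    new_w = torch.zeros(n + 1, dtype=w.dtype, device=w.device)
    # "feature absent" branch: subset size stays, weight scales by pz
    new_w[:n] = pz * w * (n - k[:n]) / (n + 1)
    # "feature present" branch: subset grows by one, weight scales by po
    new_w[1:] += po * w * k[1:] / (n + 1)
    return new_w

# the arithmetic inside one call vectorizes, but each call consumes the
# previous result, so walking a root-to-leaf path stays a sequential loop
w = torch.empty(0, dtype=torch.float64)
for pz, po in [(1.0, 1.0), (0.5, 1.0), (0.25, 0.0)]:  # hypothetical path
    w = extend(w, pz, po)
```

So the loop I could not remove is the one over tree depth, one EXTEND per node on the path; within a single step everything vectorizes.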
Don't worry, no one else is working on it right now, so it is still all yours. Now that we have TorchScript (and soon TVM), for loops are OK if they are static, because they will be unrolled. Anyway, keep us posted if you have any further updates. If you want us to look at the code, you can open a draft PR. Up to you!
@interesaaat: a quick update on using Hummingbird for SHAP. I just realised that for KernelExplainer we can simply pass in the prediction function of the Hummingbird model object. This already offers a substantial speed-up (in the case above, about 30 times faster), because KernelExplainer needs to evaluate the model on many perturbed copies of the input instances. So for any model that has no model-specific explainer and is supported by Hummingbird, e.g. SVM or KNN, this approach already offers an easy acceleration. For the model-specific explainers, e.g. TreeExplainer, we need to look into the inner workings and translate the algorithm into tensor computations.
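Concretely, something like the sketch below (the model choice and data sizes are only illustrative):

```python
import shap
from hummingbird.ml import convert
from sklearn.datasets import make_classification
from sklearn.svm import SVC

X, y = make_classification(n_samples=500, n_features=10, random_state=0)
clf = SVC(probability=True).fit(X, y)

# convert the fitted model to a PyTorch-backed Hummingbird container
hb_model = convert(clf, "pytorch")
# hb_model.to("cuda")  # move the container to GPU if one is available

# KernelExplainer only needs a prediction callable, so the accelerated
# predict_proba of the Hummingbird container can be passed in directly
background = shap.sample(X, 50)
explainer = shap.KernelExplainer(hb_model.predict_proba, background)
shap_values = explainer.shap_values(X[:5])
```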
@zhanjiezhu Nvidia RAPIDS recently released GPUTreeShap, a GPU variant implementation (https://github.com/rapidsai/gputreeshap, https://arxiv.org/pdf/2010.13972.pdf). Maybe as a starting point you can check whether you can implement that algorithm purely with tensor ops instead of writing CUDA code.
@scnakandala: many thanks for the links, very interesting. I've tried implementing TreeShap using tensor ops before but failed; I will look into their paper and implementation.