muellerzr / fastinference

A collection of inference modules for fastai2
https://muellerzr.github.io/fastinference
Apache License 2.0
89 stars, 16 forks

02_shap.interp.ipynb error? #1

Closed: ncduy0303 closed this issue 4 years ago

ncduy0303 commented 4 years ago

Hi, I tried running this notebook and got an error on the following cell:

exp = ShapInterpretation(learn)
exp.decision_plot(class_id=0, row_idx=10)

TypeError                                 Traceback (most recent call last)
<ipython-input-12-907d69fd87c2> in <module>
----> 1 exp.decision_plot(class_id=0, row_idx=10)

<ipython-input-8-27630e1fb799> in decision_plot(self, class_id, row_idx, **kwargs)
     16     def decision_plot(self, class_id=0, row_idx=-1, **kwargs):
     17         "Visualize model decision using cumulative `SHAP` values."
---> 18         shap_vals, exp_val = _get_values(self, class_id)
     19         n_rows = shap_vals.shape[0]
     20         if row_idx == -1:

<ipython-input-10-91e394550211> in _get_values(interp, class_id)
      5     exp_vals = interp.explainer.expected_value
      6     if interp.is_multi_output:
----> 7         (class_name, class_idx) = _get_class_info(interp, class_id)
      8         print(f"Classification model detected, displaying score for the class {class_name}.")
      9         print("(use `class_id` to specify another class)")

<ipython-input-9-96fbf3aee7cf> in _get_class_info(interp, class_id)
      2 def _get_class_info(interp:ShapInterpretation, class_id):
      3     "Returns class name associated with index, or vice-versa"
----> 4     if isinstance(class_id, int): class_idx, class_name = class_id, interp.class_names[class_id]
      5     else: class_idx, class_name = interp.class_names.o2i[class_id], class_id
      6     return (class_name, class_idx)

TypeError: 'NoneType' object is not subscriptable
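The traceback ends in _get_class_info: an integer class_id is used to index interp.class_names directly, so when class_names is None the subscript raises this TypeError. The sketch below reproduces that with toy stand-ins. ToyVocab, get_class_info_guarded, and the '<50k'/'>=50k' labels are hypothetical and not fastinference code; the guarded variant is only one possible way to surface a clearer message.

```python
from types import SimpleNamespace


class ToyVocab(list):
    """Hypothetical stand-in for a vocab: class names plus an o2i (object-to-index) dict."""

    def __init__(self, names):
        super().__init__(names)
        self.o2i = {name: i for i, name in enumerate(names)}


def get_class_info_original(interp, class_id):
    """Same logic as `_get_class_info` in the traceback above."""
    if isinstance(class_id, int):
        class_idx, class_name = class_id, interp.class_names[class_id]
    else:
        class_idx, class_name = interp.class_names.o2i[class_id], class_id
    return (class_name, class_idx)


def get_class_info_guarded(interp, class_id):
    """Hypothetical variant: report a missing vocab instead of a bare TypeError."""
    if interp.class_names is None:
        raise ValueError(
            "class_names is None: no vocab was found when ShapInterpretation was created"
        )
    return get_class_info_original(interp, class_id)


# Toy classifier with a vocab: lookups by index and by name both work.
classifier = SimpleNamespace(class_names=ToyVocab(["<50k", ">=50k"]))
print(get_class_info_original(classifier, 1))       # ('>=50k', 1)
print(get_class_info_original(classifier, "<50k"))  # ('<50k', 0)

# Toy interpretation whose vocab lookup came back as None, as in this issue.
no_vocab = SimpleNamespace(class_names=None)
try:
    get_class_info_original(no_vocab, 0)
except TypeError as err:
    print("TypeError:", err)  # TypeError: 'NoneType' object is not subscriptable
try:
    get_class_info_guarded(no_vocab, 0)
except ValueError as err:
    print("ValueError:", err)
```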
muellerzr commented 4 years ago

I can't seem to reproduce this. What environment are you in?

ncduy0303 commented 4 years ago

I'm using Jupyter Notebook on Gradient. Here are the versions of the libraries I'm using:

import shap, fastai2, fastcore, fastinference
fastai2.__version__, fastcore.__version__, fastinference.__version__, shap.__version__
('0.0.18', '0.1.18', '0.0.13', '0.35.0')
muellerzr commented 4 years ago

Thanks! I'll try to reproduce this on Gradient today and get back to you. For now, I can guarantee it works in Colaboratory (that's where I did my quick testing).

ncduy0303 commented 4 years ago

Thank you!

muellerzr commented 4 years ago

I don't have a Gradient account, so I'm not sure how to test that, but running it in my local Jupyter environment doesn't show any issues either. The library versions are all the same.

ncduy0303 commented 4 years ago

Hi, I'm still seeing the error on my side. I traced it to ShapInterpretation.__init__, and after changing

self.class_names = learn.dl.vocab if hasattr(learn.dl, 'vocab') else None

to

self.class_names = learn.dls.vocab if hasattr(learn.dls, 'vocab') else None

everything seems to work. Is that the right fix?

#export
class ShapInterpretation():
    "Base interpreter to use the `SHAP` interpretation library"
    def __init__(self, learn:TabularLearner, test_data=None, link='identity', l1_reg='auto', n_samples=128, **kwargs):
        "Initialize `ShapInterpretation` with a Learner, test_data, link, `n_samples`, `l1_reg`, and optional **kwargs"
        self.model = learn.model
        self.dls = learn.dls
        self.class_names = learn.dls.vocab if hasattr(learn.dls, 'vocab') else None # read from learn.dls, not learn.dl; only defined for classification problems
        self.train_data = pd.merge(learn.dls.cats, learn.dls.conts, left_index=True, right_index=True)
        self.test_data = _prepare_data(learn, test_data, n_samples)
        pred_func = partial(_predict, learn)
        self.explainer = shap.SamplingExplainer(pred_func, self.train_data, **kwargs)
        self.shap_vals = self.explainer.shap_values(self.test_data, l1_reg=l1_reg)
        self.is_multi_output = isinstance(self.shap_vals, list)
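For comparison, this sketch shows how reading the vocab from learn.dl can yield None while learn.dls still carries one, using toy SimpleNamespace objects instead of a real fastai2 Learner. resolve_class_names is a hypothetical helper, not part of fastinference, and the attribute layout of the toy learners is an assumption chosen to mirror the symptom in this thread, not a description of how fastai2 populates learn.dl.

```python
from types import SimpleNamespace


def resolve_class_names(learn):
    """Hypothetical helper: prefer the vocab on `learn.dls` (the fix in this thread),
    fall back to `learn.dl`, and return None only if neither carries one."""
    for source in (getattr(learn, "dls", None), getattr(learn, "dl", None)):
        vocab = getattr(source, "vocab", None)
        if vocab is not None:
            return vocab
    return None  # e.g. regression models, which have no vocab


# Toy learners (assumed layout): the DataLoaders object has a vocab,
# but the DataLoader stored in `dl` does not.
dls_only = SimpleNamespace(dls=SimpleNamespace(vocab=["<50k", ">=50k"]),
                           dl=SimpleNamespace())
regression = SimpleNamespace(dls=SimpleNamespace(), dl=SimpleNamespace())

# The original line from ShapInterpretation.__init__, applied to the toy learner.
original = dls_only.dl.vocab if hasattr(dls_only.dl, "vocab") else None
print(original)                         # None
print(resolve_class_names(dls_only))    # ['<50k', '>=50k']
print(resolve_class_names(regression))  # None
```

Checking learn.dls first follows the fix proposed here. Falling back to learn.dl and using getattr with a default are extra defensive choices: on a plain object without a dl attribute, getattr returns None where hasattr(learn.dl, 'vocab') would raise AttributeError.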
muellerzr commented 4 years ago

That fix looks alright to me. Would you like to put in a PR? :)