marcotcr / lime

Lime: Explaining the predictions of any machine learning classifier
BSD 2-Clause "Simplified" License

TabularExplainer | exp.as_pyplot_figure() | KeyError: 1 #701

Open Paulyzei opened 2 years ago

Paulyzei commented 2 years ago

Dear Team, @marcotcr,

first of all, thank you for your great work!

I am facing an issue while following your instructions on Submodular Pick. Any help or suggestion would be highly appreciated!

My dataset poses a binary classification problem with labels [0, 1], where 1 is the label to be explained. Explanations for label 1 do exist: exp.available_labels() returns [1].

So here is how I did it. First, I defined a function for explain_instance:

import numpy as np

# Wrapper around explain_instance that fixes the seed to ensure consistent results
def explain(data_row, predict_fn, num_features):
  np.random.seed(42)
  # Use the arguments passed in rather than the globals a_X_test[i] and predict_proba_fn
  return explainer.explain_instance(data_row=data_row, predict_fn=predict_fn, num_features=num_features, top_labels=None, labels=(1,))

The way I generate the explanation for a specific instance is:

# Generate an explanation for instance i
i = 3
exp = explain(data_row=a_X_test[i],
              predict_fn=predict_proba_fn,  # pass the function itself, not its name as a string
              num_features=5)
exp.show_in_notebook(show_all=False)

This is followed by the instantiation of a SubmodularPick object:

import warnings
from lime import submodular_pick
sp_obj = submodular_pick.SubmodularPick(explainer, a_X_test, predict_proba_fn, method='full', num_features=10, num_exps_desired=5)

Finally, I want to use pyplot to plot the 5 instances picked by SP:

[exp.as_pyplot_figure() for exp in sp_obj.sp_explanations]

However, only one instance gets plotted; the other 4 do not. The full traceback is:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In [99], line 1
----> 1 [exp.as_pyplot_figure() for exp in sp_obj.sp_explanations]

Cell In [99], line 1, in <listcomp>(.0)
----> 1 [exp.as_pyplot_figure() for exp in sp_obj.sp_explanations]

File ~/.local/lib/python3.9/site-packages/lime/explanation.py:167, in Explanation.as_pyplot_figure(self, label, **kwargs)
    154 """Returns the explanation as a pyplot figure.
    155 
    156 Will throw an error if you don't have matplotlib installed
   (...)
    164     pyplot figure (barchart).
    165 """
    166 import matplotlib.pyplot as plt
--> 167 exp = self.as_list(label=label, **kwargs)
    168 fig = plt.figure()
    169 vals = [x[1] for x in exp]

File ~/.local/lib/python3.9/site-packages/lime/explanation.py:141, in Explanation.as_list(self, label, **kwargs)
    128 """Returns the explanation as a list.
    129 
    130 Args:
   (...)
    138     given by domain_mapper. Weight is a float.
    139 """
    140 label_to_use = label if self.mode == "classification" else self.dummy_label
--> 141 ans = self.domain_mapper.map_exp_ids(self.local_exp[label_to_use], **kwargs)
    142 ans = [(x[0], float(x[1])) for x in ans]
    143 return ans

KeyError: 1

I checked sp_obj.sp_explanations to see whether the issue was with SP. However, SP returned 5 explanations, as expected:

[<lime.explanation.Explanation at 0x7fe898786400>,
 <lime.explanation.Explanation at 0x7fe897ce6430>,
 <lime.explanation.Explanation at 0x7fe89675faf0>,
 <lime.explanation.Explanation at 0x7fe896764a60>,
 <lime.explanation.Explanation at 0x7fe8987f5400>]
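
A further check that might narrow this down is to look at which labels each of these explanations actually holds. The sketch below does not use lime: StubExplanation is a hypothetical stand-in that only copies the self.local_exp[label] lookup from line 141 of the traceback, and it assumes that as_list() defaults to label=1. If an explanation was generated for label 0 only, the lookup fails in the same way as above. Whether SubmodularPick really explains only the top predicted label of each instance is an assumption here, not something the traceback confirms.

```python
# Hypothetical stand-in for lime.explanation.Explanation. It only mimics the
# dictionary lookup shown in the traceback (self.local_exp[label]).
class StubExplanation:
    def __init__(self, local_exp):
        # Maps label -> list of (feature_id, weight) pairs.
        self.local_exp = local_exp

    def available_labels(self):
        # Labels for which this explanation holds weights.
        return list(self.local_exp.keys())

    def as_list(self, label=1):
        # Assumed default label=1; raises KeyError if that label is absent.
        return [(fid, float(w)) for fid, w in self.local_exp[label]]


# Assumption: one explanation was built for label 1, the other for label 0.
explanations = [
    StubExplanation({1: [(0, 0.4), (2, -0.1)]}),
    StubExplanation({0: [(1, 0.3)]}),
]

results = []
for exp in explanations:
    try:
        exp.as_list()  # relies on the default label
        results.append("ok")
    except KeyError as err:
        results.append(f"KeyError: {err} (available: {exp.available_labels()})")

print(results)
```

On the real objects, the equivalent check would be [exp.available_labels() for exp in sp_obj.sp_explanations].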

Proceeding with W=pd.DataFrame([dict(this.as_list()) for this in sp_obj.explanations]) raises KeyError: 1 again:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In [102], line 1
----> 1 W=pd.DataFrame([dict(this.as_list()) for this in sp_obj.explanations])

Cell In [102], line 1, in <listcomp>(.0)
----> 1 W=pd.DataFrame([dict(this.as_list()) for this in sp_obj.explanations])

File ~/.local/lib/python3.9/site-packages/lime/explanation.py:141, in Explanation.as_list(self, label, **kwargs)
    128 """Returns the explanation as a list.
    129 
    130 Args:
   (...)
    138     given by domain_mapper. Weight is a float.
    139 """
    140 label_to_use = label if self.mode == "classification" else self.dummy_label
--> 141 ans = self.domain_mapper.map_exp_ids(self.local_exp[label_to_use], **kwargs)
    142 ans = [(x[0], float(x[1])) for x in ans]
    143 return ans

KeyError: 1
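
If the explanations turn out to hold different labels, one possible workaround is to ask each explanation for a label it actually has instead of relying on the default. The helper below is only a sketch: explanation_rows and the explained_label key are hypothetical names, the feature strings are made up, and _Stub merely stands in for lime's Explanation so that the example runs without lime. It relies on available_labels() and as_list(label=...), both of which appear earlier in this issue.

```python
def explanation_rows(explanations):
    """Build one {feature: weight} dict per explanation, using a label it holds."""
    rows = []
    for exp in explanations:
        label = exp.available_labels()[0]      # first label this explanation holds
        row = dict(exp.as_list(label=label))   # avoids the default-label lookup
        row["explained_label"] = label         # hypothetical extra key, kept for context
        rows.append(row)
    return rows


class _Stub:
    """Hypothetical stand-in exposing only the two methods used above."""

    def __init__(self, label, pairs):
        self._label = label
        self._pairs = pairs

    def available_labels(self):
        return [self._label]

    def as_list(self, label=1):
        if label != self._label:
            raise KeyError(label)
        return list(self._pairs)


# Made-up example data: one explanation per label.
rows = explanation_rows([
    _Stub(1, [("age > 40", 0.25)]),
    _Stub(0, [("age > 40", -0.25), ("income <= 3", 0.5)]),
])
for row in rows:
    print(row)
```

With the real objects this would be called as pd.DataFrame(explanation_rows(sp_obj.explanations)), and the same idea applies to plotting via exp.as_pyplot_figure(label=exp.available_labels()[0]). Weights explained for label 0 and for label 1 are not directly comparable, which is why the label is kept alongside each row.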

Thank you in advance for your time and effort! Cheerio, Hannes