scikit-learn-contrib / hiclass

A python library for hierarchical classification compatible with scikit-learn
BSD 3-Clause "New" or "Revised" License

An error occurred when using the Explainer of the last version #127

Open RamSnoussi opened 1 month ago

RamSnoussi commented 1 month ago

Hi @mirand863, what is the problem here? How can I correct this error?

from hiclass import LocalClassifierPerParentNode, Explainer
from sklearn.ensemble import RandomForestClassifier
import numpy as np
import shap

X_train = np.array([
    [40.7,  1. ,  1. ,  2. ,  5. ,  2. ,  1. ,  5. , 34.3],
    [39.2,  0. ,  2. ,  4. ,  1. ,  3. ,  1. ,  2. , 34.1],
    [40.6,  0. ,  3. ,  1. ,  4. ,  5. ,  0. ,  6. , 27.7],
    [36.5,  0. ,  3. ,  1. ,  2. ,  2. ,  0. ,  2. , 39.9],
])

Y_train = np.array([
    ['Gastrointestinal', 'Norovirus', ''],
    ['Respiratory', 'Covid', ''],
    ['Allergy', 'External', 'Bee Allergy'],
    ['Respiratory', 'Cold', ''],
])

X_test = np.array([[35.5,  0. ,  1. ,  1. ,  3. ,  3. ,  0. ,  2. , 37.5]])

classifier = LocalClassifierPerParentNode(local_classifier=RandomForestClassifier())
classifier.fit(X_train, Y_train)
explainer = Explainer(classifier, data=X_train, mode="tree")
explanations = explainer.explain(X_test)
print(explanations)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[1], line 26
     24 classifier.fit(X_train, Y_train)
     25 explainer = Explainer(classifier, data=X_train, mode="tree")
---> 26 explanations = explainer.explain(X_test)
     27 print(explanations)

File ~/anaconda3/envs/hiclass/lib/python3.8/site-packages/hiclass/Explainer.py:124, in Explainer.explain(self, X)
    117 check_array(X)
    119 if (
    120     isinstance(self.hierarchical_model, LocalClassifierPerParentNode)
    121     or isinstance(self.hierarchical_model, LocalClassifierPerLevel)
    122     or isinstance(self.hierarchical_model, LocalClassifierPerNode)
    123 ):
--> 124     return self._explain_with_xr(X)
    125 else:
    126     raise ValueError(f"Invalid model: {self.hierarchical_model}.")

File ~/anaconda3/envs/hiclass/lib/python3.8/site-packages/hiclass/Explainer.py:142, in Explainer._explain_with_xr(self, X)
    128 def _explain_with_xr(self, X):
    129     """
    130     Generate SHAP values for each node using the SHAP package.
    131 
   (...)
    140         An xarray Dataset consisting of SHAP values for each sample.
    141     """
--> 142     explanations = Parallel(n_jobs=self.n_jobs, backend="threading")(
    143         delayed(self._calculate_shap_values)(sample.reshape(1, -1)) for sample in X
    144     )
    146     dataset = xr.concat(explanations, dim="sample")
    147     return dataset

File ~/anaconda3/envs/hiclass/lib/python3.8/site-packages/joblib/parallel.py:1918, in Parallel.__call__(self, iterable)
   1916     output = self._get_sequential_output(iterable)
   1917     next(output)
-> 1918     return output if self.return_generator else list(output)
   1920 # Let's create an ID that uniquely identifies the current call. If the
   1921 # call is interrupted early and that the same instance is immediately
   1922 # re-used, this id will be used to prevent workers that were
   1923 # concurrently finalizing a task from the previous call to run the
   1924 # callback.
   1925 with self._lock:

File ~/anaconda3/envs/hiclass/lib/python3.8/site-packages/joblib/parallel.py:1847, in Parallel._get_sequential_output(self, iterable)
   1845 self.n_dispatched_batches += 1
   1846 self.n_dispatched_tasks += 1
-> 1847 res = func(*args, **kwargs)
   1848 self.n_completed_tasks += 1
   1849 self.print_progress()

File ~/anaconda3/envs/hiclass/lib/python3.8/site-packages/hiclass/Explainer.py:275, in Explainer._calculate_shap_values(self, X)
    273     traversed_nodes = self._get_traversed_nodes_lcpl(X)[0]
    274 elif isinstance(self.hierarchical_model, LocalClassifierPerParentNode):
--> 275     traversed_nodes = self._get_traversed_nodes_lcppn(X)[0]
    276 elif isinstance(self.hierarchical_model, LocalClassifierPerNode):
    277     traversed_nodes = self._get_traversed_nodes_lcpn(X)[0]

File ~/anaconda3/envs/hiclass/lib/python3.8/site-packages/hiclass/Explainer.py:170, in Explainer._get_traversed_nodes_lcppn(self, samples)
    164 traversals = np.empty(
    165     (samples.shape[0], self.hierarchical_model.max_levels_),
    166     dtype=self.hierarchical_model.dtype_,
    167 )
    169 # Initialize first element as root node
--> 170 traversals[:, 0] = self.hierarchical_model.root_
    172 # For subsequent nodes, calculate mask and find predictions
    173 for level in range(1, traversals.shape[1]):

ValueError: invalid literal for int() with base 10: 'hiclass::root'

RamSnoussi commented 1 month ago

Hi, the cause is that the encoding changes the type of y_ from string to int64 (line 223 of HierarchicalClassifier.py). The type of traversals (line 164 of Explainer.py) is therefore int64, whereas self.hierarchical_model.root_ (line 170 of Explainer.py) is a string. So Explainer.py should be updated.
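
The mismatch described above can be reproduced with NumPy alone, without hiclass. This is a minimal sketch: the array shape (1 sample, 3 levels) is an assumption chosen for illustration, and the root label is copied from the error message in the traceback.

```python
import numpy as np

# Assumption for illustration: one sample and three hierarchy levels.
# With integer-encoded labels, the traversal array ends up as int64.
traversals = np.empty((1, 3), dtype=np.int64)

# The root node label is still a string, as shown in the traceback.
root = "hiclass::root"

try:
    # Same kind of assignment as line 170 of Explainer.py.
    traversals[:, 0] = root
    error_message = None
except ValueError as exc:
    # NumPy cannot parse the string as an integer.
    error_message = str(exc)

print(error_message)
```

The assignment fails because NumPy tries to convert the string to the integer dtype of the array.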

RamSnoussi commented 1 month ago

Hi @mirand863, Do you have any suggestions about this issue?

mirand863 commented 1 month ago

Hi @RamSnoussi,

Can you please explain a bit how you are using the explainer? Are you using it with encoded labels, as in the other issue we were discussing previously, so that the root should also be an integer? Is that correct?

RamSnoussi commented 1 month ago

Hi @mirand863, I'm using hiclass v4.10.0 (https://github.com/scikit-learn-contrib/hiclass). Please look at the example above and the error it generates (traversals[:, 0] = self.hierarchical_model.root_ on line 170 of Explainer.py). The error occurs because traversals[:, 0] is an integer and self.hierarchical_model.root_ is a string. You modified hiclass by adding the encoder (in HierarchicalClassifier.py), but Explainer.py should be updated too.

RamSnoussi commented 1 month ago

The attached file (HierarchicalClassifier.py) is the one in my execution environment, where the encoder is used on line 221. However, the encoder has been deleted in the GitHub version (https://github.com/scikit-learn-contrib/hiclass/blob/main/hiclass/HierarchicalClassifier.py). Why? Which released version can I use? HierarchicalClassifier.txt

mirand863 commented 1 month ago

I see now. The encoder has not been released yet; it is only available in a branch called cuml: https://github.com/scikit-learn-contrib/hiclass/compare/main...cuml

I was just testing it out and never actually released it. I also did not need the explainer in my use case, so I did not update that file. I might be able to do it in the next few days if it is important for your use case. I can add tests and try to make it run without bugs with a proper release. :)
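
One possible direction for such an update, sketched with plain NumPy, is to allocate the traversal array with a dtype that can actually hold the root label. This is not the hiclass patch: the helper name allocate_traversals and its parameters are hypothetical and only illustrate the idea.

```python
import numpy as np


def allocate_traversals(n_samples, max_levels, root, label_dtype):
    """Hypothetical helper, not part of hiclass.

    Allocates an (n_samples, max_levels) array whose first column holds
    the root label. If the root cannot be stored losslessly in
    label_dtype, the array falls back to dtype=object.
    """
    try:
        # Round-trip check: cast the root and compare with the original.
        fits = bool(np.array([root]).astype(label_dtype)[0] == root)
    except (ValueError, TypeError):
        # For example, a string root with an integer label dtype.
        fits = False

    dtype = label_dtype if fits else object
    traversals = np.empty((n_samples, max_levels), dtype=dtype)
    traversals[:, 0] = root
    return traversals


# Integer-encoded labels with a string root: falls back to object dtype.
mixed = allocate_traversals(2, 3, "hiclass::root", np.int64)
print(mixed.dtype, mixed[0, 0])

# String labels wide enough for the root: the label dtype is kept.
strings = allocate_traversals(1, 2, "hiclass::root", np.dtype("<U20"))
print(strings.dtype, strings[0, 0])
```

Another option would be to give the root its own integer code in the encoder, so that the array can stay int64; which of the two fits hiclass better is for the maintainers to decide.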

RamSnoussi commented 4 weeks ago

Hi @mirand863, have you updated the Explainer file to work with the added encoder? Thanks.