mmschlk / iXAI

Fast and incremental explanations for online machine learning models. Works best with the river framework.
MIT License

TypeError: unhashable type: 'dict' #90

Open · Kirstenml opened this issue 6 months ago

Kirstenml commented 6 months ago

Hello, I ran the sample code, and it already fails while `IncrementalPFI` is being initialized. The passed loss function (`Accuracy()` in the example) is checked by `validate_loss_function`. Inside `_get_loss_function_from_river_metric`, the metric's `update` method is called with a dict as `y_pred` (`y_pred={0: 0}`). However, `update` in `metrics.Accuracy` in the current river version (0.21) no longer expects a dict, only a single value.

Here is the example code:

from river.datasets.synth import Agrawal
from river.metrics import Accuracy
from river.forest import ARFClassifier
from ixai.explainer import IncrementalPFI

stream = Agrawal(classification_function=2)
feature_names = list([x_0 for x_0, _ in stream.take(1)][0].keys())

model = ARFClassifier(n_models=10, max_depth=10, leaf_prediction='mc')

incremental_pfi = IncrementalPFI(
    model_function=model.predict_one,
    loss_function=Accuracy(),
    feature_names=feature_names,
    smoothing_alpha=0.001,
    n_inner_samples=5
)

training_metric = Accuracy()
for (n, (x, y)) in enumerate(stream.take(10_000), start=1):  # Agrawal is an infinite stream; stop after 10,000 samples
    training_metric.update(y, model.predict_one(x))   # inference
    incremental_pfi.explain_one(x, y)                 # explaining
    model.learn_one(x, y)                             # learning
    if n % 1000 == 0:
        print(f"{n}: Accuracy: {training_metric.get():.3f}, PFI: {incremental_pfi.importance_values}")

The error is:

Traceback (most recent call last):
  File "/media/user/main.py", line 10, in <module>
    incremental_pfi = IncrementalPFI(
  File "/media/user/.virtualenvs/pyenv1/lib/python3.10/site-packages/ixai/explainer/pfi.py", line 68, in __init__
    super(IncrementalPFI, self).__init__(
  File "/media/user/.virtualenvs/pyenv1/lib/python3.10/site-packages/ixai/explainer/base.py", line 71, in __init__
    self._loss_function = validate_loss_function(loss_function)
  File "/media/user/.virtualenvs/pyenv1/lib/python3.10/site-packages/ixai/utils/validators/loss.py", line 30, in validate_loss_function
    validated_loss_function = _get_loss_function_from_river_metric(river_metric=loss_function)
  File "/media/user/.virtualenvs/pyenv1/lib/python3.10/site-packages/ixai/utils/validators/loss.py", line 18, in _get_loss_function_from_river_metric
    _ = river_metric.update(y_true=0, y_pred={0: 0}).revert(y_true=0, y_pred={0: 0})
  File "/media/user/.virtualenvs/pyenv1/lib/python3.10/site-packages/river/metrics/base.py", line 93, in update
    self.cm.update(
  File "/media/user/.virtualenvs/pyenv1/lib/python3.10/site-packages/river/metrics/confusion.py", line 67, in update
    self._update(y_true, y_pred, w)
  File "/media/user/.virtualenvs/pyenv1/lib/python3.10/site-packages/river/metrics/confusion.py", line 75, in _update
    self.data[y_true][y_pred] += w
TypeError: unhashable type: 'dict'
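The last frame fails because the confusion matrix uses `y_pred` as a dictionary key. Here is a minimal sketch of that mechanism. It uses a hypothetical `ToyConfusionMatrix` that only imitates the line `self.data[y_true][y_pred] += w` from the traceback (it is not river's actual class). It shows why a plain label works while a dict such as `{0: 0}` raises the same `TypeError`:

```python
from collections import defaultdict


class ToyConfusionMatrix:
    """Hypothetical, simplified stand-in (not river's class) that mimics the
    failing line from the traceback: self.data[y_true][y_pred] += w."""

    def __init__(self):
        # Nested mapping: true label -> predicted label -> accumulated weight.
        self.data = defaultdict(lambda: defaultdict(float))

    def update(self, y_true, y_pred, w=1.0):
        # y_pred is used as a dictionary key, so it must be hashable.
        self.data[y_true][y_pred] += w
        return self


def try_update(y_pred):
    """Hypothetical helper: return None on success, or the TypeError message."""
    try:
        ToyConfusionMatrix().update(y_true=0, y_pred=y_pred)
    except TypeError as err:
        return str(err)
    return None


if __name__ == "__main__":
    print(try_update(0))       # a plain label is hashable and works
    print(try_update({0: 0}))  # a dict, as passed by the iXAI validator, fails
```

The `try_update` helper is also hypothetical. It exists only to make the two cases easy to compare side by side.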
mmschlk commented 6 months ago

Oh no... I haven't looked at the code in a while. That's the problem with implementing against a changing API. Is there a quick workaround for you?

I suspect these problems will come up more and more. However, at the moment I am not able to maintain this to keep up with the current version of river.

Kirstenml commented 6 months ago

Unfortunately, the only thing that currently works is downgrading river to version 0.14.
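Until the validator is updated, pinning river is the reported workaround. The guard below fails fast with a clear message before `IncrementalPFI` is constructed. It assumes that only 0.14.x is known to work, because this thread does not identify the exact river release that broke compatibility. `SUPPORTED_RIVER`, `major_minor`, `river_is_supported` and `check_installed_river` are hypothetical names and are not part of iXAI:

```python
from importlib import metadata
from typing import Tuple

# Hypothetical constant: the thread only reports river 0.14 as working.
SUPPORTED_RIVER = (0, 14)


def major_minor(version: str) -> Tuple[int, int]:
    """Parse 'X.Y[.Z...]' into (X, Y); raises ValueError on other formats."""
    parts = version.split(".")
    return int(parts[0]), int(parts[1])


def river_is_supported(version: str) -> bool:
    """True if the version matches the reportedly working 0.14.x series."""
    return major_minor(version) == SUPPORTED_RIVER


def check_installed_river() -> None:
    """Raise early with a readable message instead of failing deep inside
    IncrementalPFI's loss-function validation."""
    try:
        installed = metadata.version("river")
    except metadata.PackageNotFoundError:
        raise RuntimeError("river is not installed") from None
    if not river_is_supported(installed):
        raise RuntimeError(
            f"river {installed} found; iXAI was reported to work with river 0.14.x "
            "(see issue #90). Try: pip install 'river==0.14.*'"
        )
```

Calling `check_installed_river()` at the top of the example script would report the version mismatch directly, rather than through the `unhashable type: 'dict'` traceback.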