unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0

[QUESTION] How to extract the Shap 'Explanation' object to use with the original Shap package? #2567

Closed: DataScientistET closed this issue 3 weeks ago

DataScientistET commented 3 weeks ago

Is there a way to extract the SHAP object so that I can use the original SHAP package to plot diagrams that are not yet implemented in darts? For example, I would like to produce a plot like the one in the attached image, but it requires the 'Explanation' object.

madtoinou commented 3 weeks ago

Hi @ettan10,

Due to the "temporal" aspect of the forecasts, the Explanation objects are stored in a nested dictionary, keyed first by horizon and then by target component:

```python
from darts.datasets import AirPassengersDataset
from darts.explainability.shap_explainer import ShapExplainer
from darts.models import LinearRegressionModel

series = AirPassengersDataset().load()

model = LinearRegressionModel(lags=12)
model.fit(series[:-36])

shap_explain = ShapExplainer(model)
# besides drawing the plots, summary_plot() returns {horizon: {component: Explanation}}
explanations = shap_explain.summary_plot()
# 1 for the horizon, "#Passengers" for the component
print(type(explanations[1]["#Passengers"]))
# shap._explanation.Explanation
```
DataScientistET commented 3 weeks ago

Thanks @madtoinou. Also, can I confirm that the get_explanation function always returns an index relative to the t0 timestamp, regardless of which horizon parameter is passed to it? So if

the actual forecast timestamp that the SHAP values in that row belong to is 2024-01-01 03:00:00 (i.e. t=2).

DataScientistET commented 3 weeks ago

Also, how do I get the 'Explanation' object for an unseen test set? At the moment it is only returned by the summary_plot() method.

For example, building on your example above:

```python
from darts.datasets import AirPassengersDataset
from darts.explainability.shap_explainer import ShapExplainer
from darts.models import LinearRegressionModel

series = AirPassengersDataset().load()
unseen_series = series[-36:]  # e.g. the held-out points not used for training

model = LinearRegressionModel(lags=12)
model.fit(series[:-36])

shap_explain = ShapExplainer(model)
result = shap_explain.explain(foreground_series=unseen_series)

# how to get the Explanation object from here?
```