unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/

TFTExplainer: getting explanations when trained on multiple time series #2489

Open sikhapentyala opened 1 month ago

sikhapentyala commented 1 month ago

I have trained a TFT model on multiple time series (for example, on a retail dataset with 350 time series, each having a target, past covariates, and future covariates). My understanding was that TFTExplainer would give importances (temporal and variable) based on what the model learned from all the time series. To get these, I pass the background_series (and the other covariates) that I used for training to the TFTExplainer, i.e. I pass all 350 series. This gives me the following error:

Traceback (most recent call last):
  File "/main.py", line 90, in <module>
    results = explainer.explain()
  File "/lib/python3.9/site-packages/darts/explainability/tft_explainer.py", line 224, in explain
    values=np.take(attention_heads[idx], horizon_idx, axis=0).T,
IndexError: index 30 is out of bounds for axis 0 with size 30

I found that in TFTExplainer.explain(), the size of attention_heads is 30, not 350.

When I pass only one series as the background series, it works (the size of attention_heads is 1).

How can I get global explanations for the TFT model when it is trained on multiple time series?

Thank you.

madtoinou commented 4 weeks ago

Hi @sikhapentyala,

I did a bit of investigation and, as expected, the size of the attention heads does not depend on the number of series but on the batch size (which is set to 32 by default).

This means (it's not documented, and is kind of a bug) that in the current implementation the maximum number of series that can be explained with TFTExplainer is batch_size (this could be improved by iterating over batches). In the meantime, if you want to explain a lot of series, you will need to make separate calls to TFTExplainer and explain(), as in the workaround at the end of the snippet below:

import pandas as pd
import numpy as np

from darts import TimeSeries
from darts.models import TFTModel
from darts.explainability.tft_explainer import TFTExplainer
from darts.utils.timeseries_generation import datetime_attribute_timeseries, sine_timeseries

batch_size = 4
model = TFTModel(
    input_chunk_length=10,
    output_chunk_length=2,
    pl_trainer_kwargs={"accelerator": "cpu"},
    n_epochs=1,
    batch_size=batch_size,
)

possible_starts = [pd.Timestamp(date) for date in ["2000-01-01", "2005-01-01", "2010-01-01"]]
possible_ends = [pd.Timestamp(date) for date in ["2010-01-01", "2015-01-01", "2020-01-01"]]
# one more series than the batch size, to trigger the limitation
training_series = [
    sine_timeseries(
        value_frequency=i,
        start=np.random.choice(possible_starts),
        end=np.random.choice(possible_ends),
        freq="M",
    )
    for i in range(batch_size + 1)
]
future_cov = datetime_attribute_timeseries(
    pd.date_range(start=pd.Timestamp("1900-01-01"), end=pd.Timestamp("2025-01-01"), freq="M"),
    "month",
    cyclic=True,
)

# one future covariates series per training series
model.fit(series=training_series, future_covariates=[future_cov] * (batch_size + 1))

# works: at most batch_size series are explained at once
explainer = TFTExplainer(
    model,
    background_series=training_series[:batch_size],
    background_future_covariates=[future_cov] * batch_size,
)
explanations = explainer.explain()

# does not work: one series more than batch_size raises the IndexError
explainer = TFTExplainer(
    model,
    background_series=training_series[: batch_size + 1],
    background_future_covariates=[future_cov] * (batch_size + 1),
)
explanations = explainer.explain()

# workaround: split the series into chunks of at most batch_size
# and explain each chunk separately
nb_batches = len(training_series) // batch_size
if len(training_series) % batch_size != 0:
    nb_batches += 1
explanations = []
for batch_idx in range(nb_batches):
    bg_series = training_series[batch_size * batch_idx : batch_size * (batch_idx + 1)]
    fut_cov = [future_cov] * len(bg_series)
    explainer = TFTExplainer(
        model,
        background_series=bg_series,
        background_future_covariates=fut_cov,
    )
    explanations.append(explainer.explain())
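
If you then want a single, "global"-looking summary out of these per-batch results, something along the following lines should work. This is an untested sketch: it assumes the results returned by explain() expose get_encoder_importance() and that it returns one pd.DataFrame per background series (the exact return type may vary between Darts versions).

# hedged sketch: average the encoder variable importances across all
# explained series to get a single, global-looking importance summary
encoder_importances = []
for result in explanations:
    imp = result.get_encoder_importance()
    # normalize to a list of DataFrames, whether one or several series were explained
    if isinstance(imp, pd.DataFrame):
        imp = [imp]
    encoder_importances.extend(imp)

global_encoder_importance = pd.concat(encoder_importances).mean()
print(global_encoder_importance.sort_values(ascending=False))

The same averaging can be applied to get_decoder_importance() and get_static_covariates_importance() if your version exposes them.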

Just out of curiosity, what do you expect to learn from applying this to all 350 series in your training set?

sikhapentyala commented 3 weeks ago

Thank you.

My understanding is that the TFT interpretations output "...the general relationships it has learned" (Section 7 in the paper). For example, Table 3 in the paper gives the variable importance not for a single series but across all series (a single series being the time series for a given store_item pair, i.e. one entity). Through the interpretations, I wanted to see what the model has learned globally, rather than what it learned on a single batch of examples.

madtoinou commented 3 weeks ago

Nice, thank you for pointing this out.

The interpretability analysis described in Table 3 of the paper is slightly different from what is implemented in TFTExplainer. They look at the weights of the variable selection networks, which are stored in self.static_covariates_vsn, self.encoder_vsn and self.decoder_vsn and only depend on the training data (see source), whereas the Darts module returns the weights of the attention mechanism, which depend on both the training data and the input passed during prediction.

You should be able to obtain a similar table if you access those attributes and analyze them.
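
For example, something along these lines (a rough sketch, not an official API: it assumes the variable selection sub-modules live on the internal PyTorch Lightning module, which Darts exposes as model.model after training, and that the attribute names match the source linked above):

# rough sketch: inspect the learned parameters of the variable selection
# networks; attribute names are taken from the linked source and may
# change between Darts versions
pl_module = model.model  # the internal PyTorch module of the trained TFTModel
for vsn_name in ["static_covariates_vsn", "encoder_vsn", "decoder_vsn"]:
    vsn = getattr(pl_module, vsn_name, None)
    if vsn is None:
        continue
    for param_name, param in vsn.named_parameters():
        print(vsn_name, param_name, tuple(param.shape))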