unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0
7.87k stars · 850 forks

Add support for interpretable outputs in TFTModel #675

Closed dennisbader closed 1 year ago

dennisbader commented 2 years ago

We should think about adding interpretable outputs for TFTModel as done in the original paper.

Mainly: the variable selection weights and the attention weights.

doaa-altarawy commented 2 years ago

Yes, please, will be very helpful, it's one of the advantages of TFT.

MagMueller commented 1 year ago

@jacobblb As a quick-and-dirty workaround to get some insight into the variable selection weights until the feature is available, one can:

  1. Run tft.predict(n=your_horizon).
  2. While it runs, insert prints in darts/models/forecasting/tft_model.py: print(encoder_sparse_weights.mean(axis=1)) for the past-variable selection (including the past of the target), and print(decoder_sparse_weights.mean(axis=1)) for the future-variable selection. (A sketch of where these prints could live appears at the end of this comment.)
  3. To plot the results, save them to a text file, editing it if needed so that NumPy can read it.
  4. Then do something like:

    
    import numpy as np
    import matplotlib.pyplot as plt

    def plot_cov_selection(filename, title="Variable importance"):
        variable_selection = np.loadtxt(filename)
        cov_percentage = variable_selection.mean(axis=0).round(3) * 100
        # your covariate names, for example created with datetime_attribute_timeseries
        cov_names = covs.columns

        # insert the target as first entry for past covariates
        if len(cov_names) != len(cov_percentage):
            cov_names = ["Past of Target"] + list(cov_names.values)
            print(cov_names)
        plt.bar(cov_names, cov_percentage)
        plt.title(title)
        plt.xlabel("Variable", fontsize=12)
        plt.xticks(rotation=45, ha="right")
        plt.ylabel("Variable importance in %")
        plt.show()

    plot_cov_selection("value_selection_encoder.txt", title="Variable importance")
    plot_cov_selection("value_selection_decoder.txt", title="Variable importance")


To get something like

![variable_importance_past](https://user-images.githubusercontent.com/67061560/193576135-06d0e766-2f83-4706-9f0b-01af774b7e5f.png)
![a2c4b711-196c-4053-b413-87663fc23da3](https://user-images.githubusercontent.com/67061560/193576649-8066c70b-0c89-439a-bbc6-f6ffe90d259b.png)
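
For reference, here is a minimal sketch of what the prints from step 2 could look like if you write the weights straight to the text files that plot_cov_selection() reads. The tensor shapes and the exact spot inside darts' TFT forward pass are assumptions and may differ between darts versions:

    # inside the TFT module's forward pass in darts/models/forecasting/tft_model.py,
    # right after the variable selection networks produce the sparse weights
    # (a sketch; weight shapes assumed to be (batch, time, ..., n_variables))
    import numpy as np

    w_enc = encoder_sparse_weights.mean(axis=1).detach().cpu().numpy()
    w_dec = decoder_sparse_weights.mean(axis=1).detach().cpu().numpy()
    with open("value_selection_encoder.txt", "a") as f:
        np.savetxt(f, w_enc.reshape(w_enc.shape[0], -1))  # one row per sample
    with open("value_selection_decoder.txt", "a") as f:
        np.savetxt(f, w_dec.reshape(w_dec.shape[0], -1))
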
MagMueller commented 1 year ago

@doaa-altarawy Similarly, for the attention weights:

  1. In tft_model.py, in the forward pass:

    print(attn_out_weights.squeeze().sum(axis=1).numpy().tolist(), file=open("attn_out_weights.txt", "a"))

  2. Read the file:

    import ast
    import numpy as np

    with open("attn_out_weights.txt", "r") as f:
        attention_matrix = np.array([ast.literal_eval(line) for line in f])
  3. My shape is (439, 24, 192):

    • 439 is the number of forecasts,
    • each with a horizon of 24,
    • 192 is the TFT input length (168 in my case) plus 24 for the horizon: during multi-step forecasting the model does not only attend to the 168 input values. If you forecast 24 steps at once, the 20th step attends not just to the 168 input values but also to the 19 steps predicted before it.
  4. Average over the number of forecasts with attention_matrix = attention_matrix.mean(axis=0)[:, :168]. I also keep only the 168 past attention values; you could instead roll the remaining ones over to include them at the beginning.

  5. Plot the mean attention, similar to the plots in the original TFT paper:

    import matplotlib.pyplot as plt

    # average over the horizon axis to get one attention value per past time step
    plt.plot(attention_matrix.mean(axis=0))
    plt.xticks(range(0, attention_matrix.shape[1], attention_matrix.shape[0]))
    plt.xlabel("Time steps in past")
    plt.ylabel("Attention")
    plt.show()

(Plot: mean attention over the past time steps.)

  6. Or plot a heatmap:

    attention_matrix_average = attention_matrix  # already averaged over forecasts in step 4

    plt.figure(figsize=(20, 10))
    plt.imshow(attention_matrix_average, cmap="hot", interpolation="nearest")
    plt.xticks(range(0, attention_matrix_average.shape[1], attention_matrix_average.shape[0]))
    plt.xlabel("Time steps in past")
    plt.ylabel("Horizon")
    plt.show()



![e67118ec-ae0e-4668-b7d1-64440d90dd3d](https://user-images.githubusercontent.com/67061560/193628851-d51ca852-ba49-4549-9f8c-6fb8451f04d3.png)

  7. Now you can inspect how much attention your model pays to your input values (Values) at each forecast timestamp (Query), to get a kind of seasonality analysis.
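
If you keep the per-horizon matrix from step 4 instead of averaging it, you can also overlay the attention curves of a few horizons to make such seasonal patterns easier to spot. A small sketch (the chosen horizons are arbitrary):

    import matplotlib.pyplot as plt

    # attention_matrix has shape (horizon, past_steps) after step 4
    for h in (0, 11, 23):  # arbitrary horizons to compare
        plt.plot(attention_matrix[h], label=f"horizon {h + 1}")
    plt.xlabel("Time steps in past")
    plt.ylabel("Attention")
    plt.legend()
    plt.show()
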
hrzn commented 1 year ago

@MagMueller this looks very nice! We are about to introduce a submodule darts.explainability which will contain some explainability features in 0.22.0. Once this is out, would you be willing to help us work on an explainability module for the TFT model? It could do things very much along the lines of what you showed here (e.g., producing similar plots, or returning TimeSeries of attention weights over historical time steps).

MagMueller commented 1 year ago

@hrzn Thank you! Yes I would love to help you with the TFT.

Just let me know what you need, how you work together, etc. See you :)

hrzn commented 1 year ago

> @hrzn Thank you! Yes I would love to help you with the TFT.
>
> Just let me know what you need, how you work together, etc. See you :)

@MagMueller Here's how it could perhaps look. You could create a class TFTExplainer in the darts.explainability submodule (we currently have one other explainer in there, ShapExplainer, which only works for RegressionModels).

This new class would receive an already-trained TFTModel at creation, and could offer e.g. methods to get/plot variable selection weights. The class should probably inherit from ForecastingModelExplainer, which means it should also implement the explain() method. That method's purpose is to "explain" a particular (set of) forecasts and return an ExplainabilityResult, which contains TimeSeries representing "explanation values" for given forecasting horizons and components. There might be a way to use the attention weights as those "explanation values".

If needed, we can also change/adapt this ForecastingModelExplainer super-class, as currently it's only used in one implementation (ShapExplainer) and may not be 100% appropriate for TFT. We can also add new getter methods to the TFTModel class itself to get its inner parameters cleanly.
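
To make this concrete, here is a rough skeleton of what such a class could look like (just a sketch: the base-class import path, constructor, and method signatures are assumptions based on the description above, not the final API):

    # import path is an assumption and may differ between darts versions
    from darts.explainability.explainability import ForecastingModelExplainer


    class TFTExplainer(ForecastingModelExplainer):
        """Explains a fitted TFTModel (sketch, not a final design)."""

        def __init__(self, model):
            # `model` is expected to be an already-fitted TFTModel
            super().__init__(model)

        def plot_vsw(self):
            # plot the encoder/decoder variable selection weights,
            # e.g. as bar charts like the workaround earlier in this thread
            raise NotImplementedError

        def explain(self, foreground_series=None, foreground_past_covariates=None,
                    foreground_future_covariates=None):
            # run a forward pass, collect the attention weights, and wrap them
            # as TimeSeries inside an ExplainabilityResult
            raise NotImplementedError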

So to recap, I think it would be nice if users were able to run code looking as follows:

from darts.models import TFTModel
from darts.explainability import TFTExplainer

my_model = TFTModel(...)
my_model.fit(...)

explainer = TFTExplainer(my_model, ...)

# look at (or get) variable selection weights
explainer.plot_vsw(...)

# get ExplainabilityResult containing "explanation" TimeSeries, similar to your plot showing Attention over time
expl_result = explainer.explain(my_series, my_past_covs, ...)

explanation = expl_result.get_explanation(component="my_target_dimension_of_interest", horizon=10)
# (not 100% sure we need to keep horizon here ^)

I'm pinging @dennisbader who is most knowledgeable about our TFTModel and @dumjax who implemented the ShapExplainer in case they have inputs to this discussion.

We can discuss the design here on this issue, but also don't hesitate to ping me in private by sending a DM on the Darts gitter (or Discord).

Thanks!

MagMueller commented 1 year ago

@hrzn Sorry, I'm really overloaded right now and haven't taken the time yet. What is the Discord's name?

Cattes commented 1 year ago

Hi, since I am working with a TFT model right now and I am interested in understanding the connections the model has learned, I have created a draft for the class based on @MagMueller's code and @hrzn's suggestions about where to put the code and how to structure it.

There are still a few open points that are not clear to me, like how to get the variable names, or how to better structure the code to make it more robust for the different cases TFT can handle.

Here is the draft PR: https://github.com/Cattes/darts/pull/1/files

hrzn commented 1 year ago

Thanks @Cattes ! I added a couple of high level comments. Do you think it would be possible to (as much as possible) comply with a usage like this:

from darts.models import TFTModel
from darts.explainability import TFTExplainer

my_model = TFTModel(...)
my_model.fit(...)

explainer = TFTExplainer(my_model, ...)

# look at (or get) variable selection weights
explainer.plot_vsw(...)

# get ExplainabilityResult containing "explanation" TimeSeries, similar to your plot showing Attention over time
expl_result = explainer.explain(my_series, my_past_covs, ...)

explanation = expl_result.get_explanation(component="my_target_dimension_of_interest", horizon=10)
# (not 100% sure we need to keep horizon here ^)

That means we need plot_vsw() and explain() (returning TimeSeries similar to what @MagMueller implemented).
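
One way explain() could expose the attention weights as TimeSeries, along the lines of what @MagMueller computed above (a sketch; the time index is a placeholder and attention_matrix is the (horizon, past_steps) matrix from the earlier workaround):

    import pandas as pd
    from darts import TimeSeries

    # placeholder index covering the 168-step input window
    past_index = pd.date_range("2020-01-01", periods=168, freq="h")
    # mean attention per past time step
    mean_attention = attention_matrix.mean(axis=0)
    attention_ts = TimeSeries.from_times_and_values(past_index, mean_attention)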

If you feel you have something looking solid enough, you can already open a Draft PR (even if it's not perfect) in Darts repo's master branch. Thanks!

Ping @dumjax @dennisbader

Cattes commented 1 year ago

Hi @hrzn I have added a PR with an implementation of the TFTExplainer with the usage as you suggested: https://github.com/unit8co/darts/pull/1392

I have added an "Explainability" section to the 13-TFT-examples.ipynb notebook to show how to use it: Link to the notebook in my PR

I am looking forward to comments on how to improve it :+1:

madtoinou commented 1 year ago

Solved by #1392.