ECMWFCode4Earth / ml_drought

Machine learning to better predict and understand drought. Moving to github.com/ml-clim
https://ml-clim.github.io/drought-prediction/

Add shap explanations to models #46

Closed gabrieltseng closed 5 years ago

gabrieltseng commented 5 years ago

PyTorch is already in the environment file, since this will be merged after the neural network work.

gabrieltseng commented 5 years ago

Example plot:

[Image: SMsurf_linear_regression]

tommylees112 commented 5 years ago

> Example plot: [Image: SMsurf_linear_regression]

This is so interesting, particularly that the t-3 month lag and the t-1 month lag are important whereas the t-2 month lag is unimportant.

Am I interpreting that correctly?

gabrieltseng commented 5 years ago

Sort of; these SHAP values map directly onto the final VHI (the sum of the SHAP values, plus the explainer's base value, gives the final VHI prediction). This means that the importance of a feature is better measured by its magnitude than by its signed value.

The first few values are around 0, so not very important. The final 3 values are important: two contribute to a higher VHI value, and one contributes to a lower VHI value.
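
For reference, that additivity can be checked directly with shap. A minimal standalone sketch; the sklearn model and random data here are illustrative, not our pipeline:

import numpy as np
import shap
from sklearn.linear_model import LinearRegression

# illustrative data: 100 background instances, 10 test instances, 5 features
rng = np.random.RandomState(0)
X_background = rng.randn(100, 5)
y = X_background @ rng.randn(5)
X_test = rng.randn(10, 5)

model = LinearRegression().fit(X_background, y)

# explain the test instances against the background distribution
explainer = shap.LinearExplainer(model, X_background)
shap_values = explainer.shap_values(X_test)

# additivity: base value + sum of per-feature SHAP values
# recovers the model's prediction for each instance
reconstructed = explainer.expected_value + shap_values.sum(axis=1)
assert np.allclose(reconstructed, model.predict(X_test))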

tommylees112 commented 5 years ago

[Image: mean_variable_importance_linear_regression]

For the linear regression I plotted the MEAN spatial and temporal Shapley values. They offer some weird explanations, where VHI contributes negatively to the prediction.

Am I right in interpreting this as: 'the mean importance of historical VHI in predicting current VHI at a one month lead time is negative, so if the previous VHI is high, the predicted VHI is likely to be low'?
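
For reference, a sketch of how such a mean could be computed, assuming the SHAP values are stacked into a (pixels, timesteps, features) array; the names, shapes, and stand-in data here are assumptions, not the repo's actual API:

import numpy as np

# hypothetical SHAP output: one explanation per pixel, per input month, per variable
n_pixels, n_timesteps, n_features = 1000, 12, 5
shap_values = np.random.randn(n_pixels, n_timesteps, n_features)  # stand-in values

spatial_mean = shap_values.mean(axis=0)      # average over space -> (n_timesteps, n_features)
mean_importance = spatial_mean.mean(axis=0)  # average over time  -> (n_features,)

# averaging raw SHAP values lets positive and negative contributions cancel,
# so a feature can look "negative" on average; averaging magnitudes instead
# measures importance irrespective of sign (per the earlier comment)
mean_magnitude = np.abs(shap_values).mean(axis=(0, 1))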

tommylees112 commented 5 years ago

[Image: ts_variable_importance_linear_regression]

But, to be fair, the time series versions of these plots look reasonable.
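
A per-timestep plot like this could be produced along these lines; the same assumed (pixels, timesteps, features) layout as above, with stand-in data and illustrative feature names:

import matplotlib.pyplot as plt
import numpy as np

shap_values = np.random.randn(1000, 12, 5)  # stand-in values
feature_names = ['precip', 'temp', 'evap', 'ndvi', 'vhi']  # illustrative names

# average over space, keeping the time dimension
spatial_mean = shap_values.mean(axis=0)  # (n_timesteps, n_features)

for i, name in enumerate(feature_names):
    plt.plot(spatial_mean[:, i], label=name)
plt.xlabel('input timestep')
plt.ylabel('mean SHAP value')
plt.legend()
plt.show()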

tommylees112 commented 5 years ago

I am getting an error with the LinearNetwork explainer:

Preamble:

# train models
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from src.models import LinearRegression, LinearNetwork, Persistence
from src.models.data import DataLoader

data_path = Path('data')

# train the baseline linear regression
l = LinearRegression(data_path)
l.train()

# train the linear network (a single layer of size 100)
ln = LinearNetwork(layer_sizes=[100], data_folder=data_path)
ln.train()

The error comes when trying to run the explainer:

In [30]: test_arrays_loader = DataLoader(
    ...:     data_path=data_path, batch_file_size=1,
    ...:     shuffle_data=False, mode='test'
    ...: )
    ...: key, val = list(next(iter(test_arrays_loader)).items())[0]
    ...: explanations = ln.explain(val.x)
    ...:
Extracting a sample of the training data
data/features/train/1985_4 returns no values. Skipping
data/features/train/1986_1 returns no values. Skipping
data/features/train/1985_12 returns no values. Skipping
data/features/train/1985_7 returns no values. Skipping
data/features/train/1985_6 returns no values. Skipping
data/features/train/1985_5 returns no values. Skipping
data/features/train/1985_9 returns no values. Skipping
data/features/train/1985_11 returns no values. Skipping
data/features/train/1985_8 returns no values. Skipping
data/features/train/1985_3 returns no values. Skipping
data/features/train/1985_2 returns no values. Skipping
data/features/train/1985_10 returns no values. Skipping
data/features/train/1985_1 returns no values. Skipping
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-30-0ed9ca150eae> in <module>
      4 )
      5 key, val = list(next(iter(test_arrays_loader)).items())[0]
----> 6 explanations = ln.explain(val.x)

~/ml_drought/src/models/linear_network.py in explain(self, x)
     60                 self.model, background_samples)
     61
---> 62         return self.explainer.shap_values(x)
     63
     64     def train(self, num_epochs: int = 1,

~/miniconda3/envs/esowc-drought/lib/python3.7/site-packages/shap/explainers/deep/__init__.py in shap_values(self, X, ranked_outputs, output_rank_order)
    117         were chosen as "top".
    118         """
--> 119         return self.explainer.shap_values(X, ranked_outputs, output_rank_order)

~/miniconda3/envs/esowc-drought/lib/python3.7/site-packages/shap/explainers/deep/deep_pytorch.py in shap_values(self, X, ranked_outputs, output_rank_order)
    156                 tiled_X = [X[l][j:j + 1].repeat(
    157                                    (self.data[l].shape[0],) + tuple([1 for k in range(len(X[l].shape) - 1)])) for l
--> 158                            in range(len(X))]
    159                 joint_x = [torch.cat((tiled_X[l], self.data[l]), dim=0) for l in range(len(X))]
    160                 # run attribution computation graph

~/miniconda3/envs/esowc-drought/lib/python3.7/site-packages/shap/explainers/deep/deep_pytorch.py in <listcomp>(.0)
    155                 # tile the inputs to line up with the background data samples
    156                 tiled_X = [X[l][j:j + 1].repeat(
--> 157                                    (self.data[l].shape[0],) + tuple([1 for k in range(len(X[l].shape) - 1)])) for l
    158                            in range(len(X))]
    159                 joint_x = [torch.cat((tiled_X[l], self.data[l]), dim=0) for l in range(len(X))]

ValueError: operands could not be broadcast together with shape (55,) (3,)
tommylees112 commented 5 years ago

test_arrays_loader = DataLoader(
    data_path=data_path, batch_file_size=1,
    shuffle_data=False, mode='test', to_tensor=True
)
key, val = list(next(iter(test_arrays_loader)).items())[0]
explanations = ln.explain(val.x)

Tried adding the to_tensor=True argument, but then got a whole load of these warnings:

...
Warning: unrecognized nn.Module: LinearBlock
Warning: unrecognized nn.Module: LinearBlock
Warning: unrecognized nn.Module: LinearBlock
...
gabrieltseng commented 5 years ago

The reason you get all those warnings is that the DeepExplainer is run once per instance, so if you pass in the entire batch, that's ~30,000 instances.

It works better if you limit it to 3 instances - I've updated the script to reflect that.
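
For reference, a minimal sketch of that limit, reusing the to_tensor=True loader from the earlier comment (the slice of 3 matches the updated script, but val.x[:3] as the exact mechanism is an assumption):

test_arrays_loader = DataLoader(
    data_path=data_path, batch_file_size=1,
    shuffle_data=False, mode='test', to_tensor=True
)
key, val = list(next(iter(test_arrays_loader)).items())[0]

# explain only the first 3 instances rather than the whole (~30,000-instance) batch
explanations = ln.explain(val.x[:3])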