feedzai / timeshap

TimeSHAP explains Recurrent Neural Network predictions.

TimeSHAP for RNN-based recommender system #27

Closed: ashrimal2 closed this issue 1 year ago

ashrimal2 commented 1 year ago

I am working on a sequential recommender system using GRU4Rec. Can I use TimeSHAP for such a problem? Do you happen to have a demo for this kind of use case?

Thanks

JoaoPBSousa commented 1 year ago

Hi @ashrimal2,

Could you please provide more context for your issue so we can give a better-informed answer? In any case, TimeSHAP can explain a specific output of your sequence model. If your use case produces a score for a set of output "items", TimeSHAP should be able to explain each of these output scores individually.

All of our available demos are published in the package repository on GitHub. If you test TimeSHAP on your use case, we would appreciate the feedback, and if you have an example you can share, we would be happy to add it to our demos.

We hope this was helpful, don't hesitate to contact us if you have any further questions or comments.

ashrimal2 commented 1 year ago

Hello @JoaoPBSousa ,

We have developed a sequential movie recommendation system (GRU4Rec) using the ml-100k dataset. Suppose a user has already watched movies with IDs 1, 2, and 3, and the system suggests movie ID 4 as the next one to watch. My objective is to identify which specific movie in the sequence 1, 2, 3 had the greatest influence on that recommendation. We are currently evaluating whether TimeSHAP would be right for our use case. Please share your thoughts.

JoaoPBSousa commented 1 year ago

Hello @ashrimal2,

From what you describe, it is possible to obtain those explanations using TimeSHAP event-level explanations when explaining the model score for movie ID 4.
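As a purely illustrative sketch (the model object, its scoring call, and the single-feature layout below are assumptions about your setup, not part of TimeSHAP), the piece you need is a wrapper function that receives a 3-D numpy array of sequences and returns one score per sequence, here the score of the candidate movie being explained:

import numpy as np

# Hypothetical wrapper: `gru4rec_model` and its `score(seq)` call stand in for
# your own model; the hard requirement is the input/output contract below.
def make_score_fn(gru4rec_model, candidate_item):
    # TimeSHAP explains a function that receives a 3-D numpy array
    # (n_sequences, n_events, n_features) and returns one score per
    # sequence, shaped (n_sequences, 1).
    def f(sequences: np.ndarray) -> np.ndarray:
        scores = [gru4rec_model.score(seq)[candidate_item] for seq in sequences]
        return np.asarray(scores, dtype=float).reshape(-1, 1)
    return f

# A single user history (movies 1, 2, 3) as one sequence of three events,
# each event described by one feature (the item_id).
watched = np.array([[[1], [2], [3]]], dtype=float)      # shape (1, 3, 1)
f_hs = make_score_fn(gru4rec_model, candidate_item=4)   # explains the score for movie 4

With an f_hs defined along these lines, the event-level explanations attribute the score for movie ID 4 across the three history events, which is exactly the "which watched movie mattered most" question.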

If you have any further questions or comments, don't hesitate to contact us.

ashrimal2 commented 1 year ago

Thank you. I do have some questions regarding the example you shared (AReM). I am trying to use that example as a reference for my problem, but I am not sure what my model_features, label_features, and sequence_id_feat should be for the ml-100k dataset. Any suggestions? The dataset only has four columns: user_id, item_id, timestamp, and rating. @JoaoPBSousa

JoaoPBSousa commented 1 year ago

Hi @ashrimal2,

In TimeSHAP, model_features corresponds to the columns of the dataset that your model uses to make predictions, label_features corresponds to the label columns, and sequence_id_feat is the column that identifies individual sequences; this column can be native to the dataset or generated by you.

Regarding the ml-100k dataset, I believe it contains features for both the users and the movies, so some preprocessing of the dataset may be required.
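Purely as an illustration (the assignments below are assumptions about your GRU4Rec setup, not something TimeSHAP prescribes), the four ml-100k columns you mention could map to these arguments roughly as follows:

import pandas as pd

# Illustrative mapping for a table with columns user_id, item_id, timestamp, rating.
model_features = ["item_id"]      # columns the GRU4Rec model actually consumes
label_feat = "rating"             # label column, if your setup uses one
sequence_id_feat = "user_id"      # one sequence per user
time_feat = "timestamp"           # orders the events inside each sequence

def build_sequences(df: pd.DataFrame) -> pd.DataFrame:
    # Sort so that each user's events appear in chronological order before
    # being converted to the 3-D (n_sequences, n_events, n_features) array.
    return df.sort_values([sequence_id_feat, time_feat])

Whether rating belongs in model_features or only acts as a label depends entirely on how your GRU4Rec model was trained; the names above are just one plausible assignment.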

If you have any further questions or comments, don't hesitate to contact us.

ashrimal2 commented 1 year ago

Hello @JoaoPBSousa,

I was wondering if you could kindly point me to additional TimeSHAP example notebooks (if available) similar to the AReM notebook. I am eager to gain a more comprehensive understanding of the practical applications of TimeSHAP, and I believe examining more examples would help me get there.

Thank you for your time and assistance in this matter.

JoaoPBSousa commented 1 year ago

Hi @ashrimal2,

We currently do not have any other examples of TimeSHAP besides the ones on the repository. In case you develop a new example that you can share, we are happy to add it to the repository. If you have any questions regarding TimeSHAP's application or inner workings, feel free to share them so we can try to answer them and help you!

Additionally, you can find a video explanation of TimeSHAP here. We hope this helps answer some of your questions.

ashrimal2 commented 1 year ago

Hi @JoaoPBSousa , I am getting an AssertionError when I try to run the following code:

pruning_dict = {
    "tol": 0.025,
}
# local_pruning(f, data, pruning_dict, baseline, entity_uuid, entity_col, verbose)
coal_plot_data, coal_prun_idx = local_pruning(
    f_hs,                  # model function being explained
    formatted_pos_x_data,  # 3-D sequence data
    pruning_dict,
    average_event_array,   # baseline
    positive_sequence_id,  # entity_uuid
    "user_id",             # entity_col
    False,                 # verbose
)
# coal_prun_idx is in negative terms
pruning_idx = formatted_pos_x_data.shape[1] + coal_prun_idx
pruning_plot = plot_temp_coalition_pruning(coal_plot_data, coal_prun_idx, plot_limit=40)
pruning_plot

Trace Logs:

15 Mar 20:03    INFO  phi = [0.         1.38757396]
15 Mar 20:03    INFO  phi = [0.         1.87885606]
15 Mar 20:03    INFO  phi = [ 0.         -2.72752589]
15 Mar 20:03    INFO  phi = [ 0.         -2.75893843]
15 Mar 20:03    INFO  phi = [0.         0.40875238]
15 Mar 20:03    INFO  phi = [0.         3.78581321]
15 Mar 20:03    INFO  phi = [ 0.         -0.70323297]
15 Mar 20:03    INFO  phi = [0.         1.23742276]
15 Mar 20:03    INFO  phi = [ 0.         -3.66250378]
15 Mar 20:03    INFO  phi = [0.        2.3613646]
15 Mar 20:03    INFO  phi = [0.         2.80285966]
15 Mar 20:03    INFO  phi = [0.         6.08620739]
15 Mar 20:03    INFO  phi = [ 0.         -0.04296517]
15 Mar 20:03    INFO  phi = [0.         3.67728138]
15 Mar 20:03    INFO  phi = [0.         5.00652218]
15 Mar 20:03    INFO  phi = [ 0.         -1.03172055]
15 Mar 20:03    INFO  phi = [ 0.        -5.2641871]
15 Mar 20:03    INFO  phi = [ 0.         -0.80303103]
15 Mar 20:03    INFO  phi = [ 0.         -0.42647219]
15 Mar 20:03    INFO  phi = [0.         3.39504182]
15 Mar 20:03    INFO  phi = [ 0.         -0.45266545]
15 Mar 20:03    INFO  phi = [ 0.         -0.40009332]
15 Mar 20:03    INFO  phi = [0.         1.66476148]
15 Mar 20:03    INFO  phi = [0.         5.20714211]
15 Mar 20:03    INFO  phi = [ 0.        -3.2293855]
15 Mar 20:03    INFO  phi = [0.         0.99271446]
15 Mar 20:03    INFO  phi = [ 0.         -0.21819913]
15 Mar 20:03    INFO  phi = [0.         0.52745152]
15 Mar 20:03    INFO  phi = [ 0.         -0.33610725]
15 Mar 20:03    INFO  phi = [ 0.         -7.86905575]
15 Mar 20:03    INFO  phi = [ 0.         -2.61067092]
15 Mar 20:03    INFO  phi = [0.         3.03786129]
15 Mar 20:03    INFO  phi = [0.         4.39966464]
15 Mar 20:03    INFO  phi = [ 0.         -6.22830749]
15 Mar 20:03    INFO  phi = [ 0.         -0.17829688]
15 Mar 20:03    INFO  phi = [0.         1.65582728]
15 Mar 20:03    INFO  phi = [ 0.         -1.40383154]
15 Mar 20:03    INFO  phi = [0.         1.17879003]
15 Mar 20:03    INFO  phi = [ 0.         -0.03478217]
15 Mar 20:03    INFO  phi = [ 0.         -0.16450088]
15 Mar 20:03    INFO  phi = [ 0.        -0.9192012]
15 Mar 20:03    INFO  phi = [ 0.         -1.48971289]
15 Mar 20:03    INFO  phi = [0.         0.63190988]
15 Mar 20:03    INFO  phi = [0.         2.72020227]
15 Mar 20:03    INFO  phi = [ 0.         -3.19890106]
15 Mar 20:03    INFO  phi = [0.         3.33758008]
15 Mar 20:03    INFO  phi = [ 0.         -3.22817004]
15 Mar 20:03    INFO  phi = [0.         1.81104854]
15 Mar 20:03    INFO  phi = [0.         1.44905442]
15 Mar 20:03    INFO  phi = [0.         1.14345914]
15 Mar 20:03    INFO  phi = [0.         2.63593149]
15 Mar 20:03    INFO  phi = [0.         2.19831732]
15 Mar 20:03    INFO  phi = [ 0.         -4.48448205]
15 Mar 20:03    INFO  phi = [0.         0.18271463]
15 Mar 20:03    INFO  phi = [ 0.         -3.07255054]
15 Mar 20:03    INFO  phi = [0.         0.28144789]
15 Mar 20:03    INFO  phi = [ 0.         -0.39432693]
15 Mar 20:03    INFO  phi = [ 0.         -0.24890184]
15 Mar 20:03    INFO  phi = [0.         2.03434974]
15 Mar 20:03    INFO  phi = [0.         0.22579968]
15 Mar 20:03    INFO  phi = [0.         1.54109418]
15 Mar 20:03    INFO  phi = [0.         5.80791092]
15 Mar 20:03    INFO  phi = [0.         0.16150802]
15 Mar 20:03    INFO  phi = [ 0.         -0.70236762]
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
/tmp/ipykernel_85342/2944577827.py in <module>
      2     "tol": 0.025,
      3 }
----> 4 coal_plot_data, coal_prun_idx = local_pruning(
      5     f_hs,
      6     formatted_pos_x_data,

~/anaconda3/lib/python3.9/site-packages/timeshap/explainer/pruning.py in local_pruning(f, data, pruning_dict, baseline, entity_uuid, entity_col, verbose)
    268         if baseline is None:
    269             raise ValueError("Baseline is not defined")
--> 270         coal_prun_idx, coal_plot_data = calculate_pruning()
    271         if pruning_dict.get("path") is not None:
    272             # create directory

~/anaconda3/lib/python3.9/site-packages/timeshap/explainer/pruning.py in calculate_pruning()
    255         if baseline is None:
    256             raise ValueError("Baseline is not defined")
--> 257         coal_prun_idx, coal_plot_data = temp_coalition_pruning(f,
    258                                                                data,
    259                                                                baseline,

~/anaconda3/lib/python3.9/site-packages/timeshap/explainer/pruning.py in temp_coalition_pruning(f, data, baseline, tolerance, ret_plot_data, verbose)
    177     for seq_len in range(data.shape[1], -1, -1):
    178         explainer = TimeShapKernel(f, baseline, 0, "pruning")
--> 179         shap_values = explainer.shap_values(data, pruning_idx=seq_len, **{'nsamples': 4})
    180         if ret_plot_data:
    181             plot_data += [['Sum of contribution of events \u003E t', -data.shape[1]+seq_len, shap_values[0]]]

~/anaconda3/lib/python3.9/site-packages/timeshap/explainer/kernel/timeshap_kernel.py in shap_values(self, X, pruning_idx, **kwargs)
    300             out = np.zeros(explanation.shape[0])
    301             if isinstance(explanation.shape, tuple) and len(explanation.shape) == 2:
--> 302                 assert explanation.shape[1] == 1
    303                 out[:] = explanation[:, 0]
    304             else:

AssertionError:

JoaoPBSousa commented 1 year ago

Hi @ashrimal2,

Could you please provide more details regarding the inputs you are passing to the local_pruning method, especially the shapes of the variables formatted_pos_x_data and average_event_array?

ashrimal2 commented 1 year ago

Hi @JoaoPBSousa , the shape of both average_event_array and formatted_pos_x_data is (1, 50, 2). Example:

array([[[ 143,  215], [ 143,  737], [ 143,  951], [ 143,  262], [ 143,  269],
        [ 143,  140], [ 143,  116], [ 143,  961], [ 143, 1410], [ 143,  642],
        [ 143, 1549], [ 143,  252], [ 143,  653], [ 143, 1561], [ 143,  840],
        [ 143,  421], [ 143, 1230], [ 143,  891], [ 143,   99], [ 143,  185],
        [ 143,  882], [ 143,  555], [ 143, 1001], [ 143,  852], [ 143,  270],
        [ 143,  610], [ 143,  631], [ 143,   96], [ 143,  361], [ 143,  145],
        [ 143, 1262], [ 143,  816], [ 143,  275], [ 143, 1117], [ 143, 1445],
        [ 143,  667], [ 143,  931], [ 143,  336], [ 143,  512], [ 143,  701],
        [ 143,  326], [ 143,  290], [ 143,   73], [ 143, 1167], [ 143,  131],
        [ 143,  398], [ 143, 1275], [ 143,  960], [ 143, 1403], [ 143,   61]]])

JoaoPBSousa commented 1 year ago

Hi @ashrimal2 ,

We are having trouble replicating your issue. Can you provide a code snippet and all the input parameters to the method so we can replicate the issue and help you?
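In the meantime, one hedged suggestion: the assertion that fails (explanation.shape[1] == 1 in timeshap_kernel.py, line 302 of the traceback above) requires the explained function to return a single score column per sequence. A quick check along these lines, using the variables from your snippet, may narrow the problem down:

import numpy as np

# TimeSHAP explains a function mapping (n_sequences, n_events, n_features)
# to a single score per sequence, i.e. output shape (n_sequences, 1).
out = np.asarray(f_hs(formatted_pos_x_data))
print(formatted_pos_x_data.shape)  # (1, 50, 2) in your case
print(out.shape)                   # should be (1, 1); e.g. (1, n_items) would trip the assert

If f_hs currently returns scores for every candidate item, wrapping it so it returns only the score of the item being explained should avoid the AssertionError.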

JoaoPBSousa commented 1 year ago

Closing this issue due to inactivity. If this error persists or you have any further questions, feel free to re-open the issue or create a new one.