sicara / tf-explain

Interpretability Methods for tf.keras models with Tensorflow 2.x
https://tf-explain.readthedocs.io
MIT License
1.02k stars, 112 forks

Explain tree data #173

Open rupjit-bo opened 2 years ago

rupjit-bo commented 2 years ago

Awesome work on the library. I discovered it recently, and it seems to have a lot of good features. I want to try it on tree-structured data (abstract syntax trees). I have trained a network called TBCNN, and I want to check which nodes in a tree push the prediction toward 1 or 0.

My data is stored as a list of dictionaries, where each nested dictionary represents one data point (a single tree):

[{'node': 'ci_root', 'children': [{'node': 'ci_class_decl', 'children': [{'node': 'ci_modifiers', 'children': [{'node': 'ci_modifier', 'children': []}]}, {'node': 'ci_type', 'children': [{'node': 'SimpleType', 'children': [{'node': 'SimpleName', 'children': []}]}]}, ... ]

Note: this is only part of one data point.
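To make the structure concrete, here is a minimal sketch of how one such tree can be walked so that every node gets a stable identifier (its path of child indices from the root). It is an illustration only: `iter_nodes` is a hypothetical helper, not part of tf-explain or of the TBCNN code, and `sample` is a hand-written tree truncated to the part shown above, not real data.

```python
from typing import Dict, Iterator, Tuple

# Assumed node format, as in the excerpt above:
# {'node': <str>, 'children': [<node>, ...]}
Tree = Dict[str, object]
Path = Tuple[int, ...]


def iter_nodes(tree: Tree, path: Path = ()) -> Iterator[Tuple[Path, str]]:
    """Yield (path, node_name) for every node, in pre-order.

    `path` is the sequence of child indices leading from the root to
    the node, so it uniquely identifies a node inside one tree.
    """
    yield path, tree["node"]
    for index, child in enumerate(tree["children"]):
        yield from iter_nodes(child, path + (index,))


# Hand-written example, truncated to the part of the data point shown above.
sample = {
    "node": "ci_root",
    "children": [
        {"node": "ci_class_decl", "children": [
            {"node": "ci_modifiers", "children": [
                {"node": "ci_modifier", "children": []},
            ]},
            {"node": "ci_type", "children": [
                {"node": "SimpleType", "children": [
                    {"node": "SimpleName", "children": []},
                ]},
            ]},
        ]},
    ],
}

nodes = list(iter_nodes(sample))
for path, name in nodes:
    print(path, name)
```

Each `(path, name)` pair is one candidate unit for attribution, which is the granularity at which I would like to see importance scores.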

Also, I use word2vec-style embeddings for each node.

Is there any way I can use the tool to identify the nodes that affect the prediction?

Thank you