albermax / innvestigate

A toolbox to iNNvestigate neural networks' predictions!

Can this be used for a network with multiple inputs? #233

Closed: alideatsch closed this issue 3 years ago

alideatsch commented 3 years ago

I have a CNN that takes both images and clinical data as input. Can I use this LRP implementation on that network?

If so, any hints on how? I keep getting a KeyError like this:

  File "mci_train.py", line 1019, in <module>
    evaluate_net(seed)
  File "mci_train.py", line 143, in evaluate_net
    LRP_analysis = netCNN.LRP_heatmap(test_data, j)
  File "/data/data_wnx3/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/utils/models.py", line 121, in LRP_heatmap
    analysis = analyzer.analyze([[test_mri[img_number]],[test_jac[img_number]],[test_xls[img_number]]])
  File "/data/data_wnx3/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/innvestigate/analyzer/base.py", line 473, in analyze
    self.create_analyzer_model()
  File "/data/data_wnx3/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/innvestigate/analyzer/base.py", line 411, in create_analyzer_model
    model, stop_analysis_at_tensors=stop_analysis_at_tensors)
  File "/data/data_wnx3/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/innvestigate/analyzer/relevance_based/relevance_analyzer.py", line 499, in _create_analysis
    return super(LRP, self)._create_analysis(*args, **kwargs)
  File "/data/data_wnx3/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/innvestigate/analyzer/base.py", line 711, in _create_analysis
    return_all_reversed_tensors=return_all_reversed_tensors)
  File "/data/data_wnx3/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/innvestigate/analyzer/base.py", line 700, in _reverse_model
    return_all_reversed_tensors=return_all_reversed_tensors)
  File "/data/data_wnx3/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/innvestigate/utils/keras/graph.py", line 1143, in reverse_model
    for tmp in model.inputs
  File "/data/data_wnx3/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/innvestigate/utils/keras/graph.py", line 1144, in <listcomp>
    if tmp not in stop_mapping_at_tensors]
  File "/data/data_wnx3/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/innvestigate/utils/keras/graph.py", line 1004, in get_reversed_tensor
    tmp = reversed_tensors[tensor]
KeyError: <tf.Tensor 'input_3:0' shape=(?, 91, 109, 91, 1) dtype=float32>

alideatsch commented 3 years ago

Got it figured out. It does work with multiple inputs; I just had to tell it to skip the non-imaging inputs when calculating the reversed tensors.
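The idea can be illustrated with a minimal pure-Python sketch. This is not innvestigate's implementation, and the toy model, its weights, and the `explain`/`skip` names are hypothetical. It applies the LRP epsilon rule to a single linear output unit fed by the concatenation of two inputs ("image" and "clinical"). It then splits the relevance back per input and filters out inputs listed in `skip` using the same `if ... not in ...` pattern that appears in the traceback's list comprehension.

```python
# Hypothetical toy model: two inputs ("image" and "clinical") are concatenated
# and fed to one linear output unit with no bias. Weights are illustrative only.


def lrp_epsilon_dense(x, w, relevance_out, eps=1e-6):
    """LRP epsilon rule for one linear unit z = sum_i x_i * w_i.

    Each input receives relevance in proportion to its contribution x_i * w_i;
    the signed epsilon keeps the denominator away from zero.
    """
    z = sum(xi * wi for xi, wi in zip(x, w))
    denom = z + (eps if z >= 0 else -eps)
    return [xi * wi / denom * relevance_out for xi, wi in zip(x, w)]


def explain(inputs, weights, skip=frozenset()):
    """Return per-input relevance, omitting inputs whose names are in `skip`.

    `inputs` maps an input name to its feature list; features are concatenated
    in dict order, which must match the layout of `weights`.
    """
    names = list(inputs)
    x = [v for name in names for v in inputs[name]]

    # LRP starts from the output score of the explained unit.
    output = sum(xi * wi for xi, wi in zip(x, weights))
    r = lrp_epsilon_dense(x, weights, relevance_out=output)

    # Split the concatenated relevance vector back into one slice per input.
    relevance, start = {}, 0
    for name in names:
        end = start + len(inputs[name])
        relevance[name] = r[start:end]
        start = end

    # Report only the inputs the caller did not ask to skip.
    return {name: relevance[name] for name in names if name not in skip}


if __name__ == "__main__":
    inputs = {"image": [1.0, 2.0], "clinical": [0.5, 3.0]}
    weights = [0.5, 0.25, 1.0, 0.5]
    print(explain(inputs, weights))                     # relevance for both inputs
    print(explain(inputs, weights, skip={"clinical"}))  # image relevance only
```

In this toy, skipping an input only drops it from the result. Its share of the relevance is still computed, so the image relevances alone need not sum to the output score, while the relevances of all inputs together approximately do.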