vlawhern / arl-eegmodels

This is the Army Research Laboratory (ARL) EEGModels Project: A Collection of Convolutional Neural Network (CNN) models for EEG signal classification, using Keras and TensorFlow

Problem with the feature explainability methods #35

Open 2bben opened 2 years ago

2bben commented 2 years ago

Hi, I have gotten DeepLIFT to work and understood the method, but I have not managed to implement the two other methods mentioned in [1]:

For the first method, summarizing averaged outputs of hidden unit activations, how do I extract these from the trained model?

For the second method, visualizing the convolutional kernel weights, which layer's weights should I use?

vlawhern commented 2 years ago

So for the first method, the spatial filters are extracted from the DepthwiseConv2D layer in EEGNet. More specifically,

model = EEGNet(...) # define some EEGNet configuration
model.fit(...)   # fit the model
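
For concreteness, a hypothetical EEGNet-8,2 setup (parameter names follow EEGModels.py in this repo; X_train and Y_train are placeholders for your own data):

from EEGModels import EEGNet

# EEGNet-8,2: F1 = 8 temporal filters, D = 2 spatial filters per temporal filter
model = EEGNet(nb_classes=2, Chans=64, Samples=128,
               kernLength=64, F1=8, D=2, F2=16)
model.compile(loss='categorical_crossentropy', optimizer='adam',
              metrics=['accuracy'])
model.fit(X_train, Y_train, batch_size=16, epochs=100)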

You can use model.layers to show all the different layers of the model, which should look like this:

[<tensorflow.python.keras.engine.input_layer.InputLayer at 0x7f2f467f7d90>,
 <tensorflow.python.keras.layers.convolutional.Conv2D at 0x7f2f46857850>,
 <tensorflow.python.keras.layers.normalization_v2.BatchNormalization at 0x7f2f468ac310>,
 <tensorflow.python.keras.layers.convolutional.DepthwiseConv2D at 0x7f2f3c329ca0>,
 <tensorflow.python.keras.layers.normalization_v2.BatchNormalization at 0x7f2f3c3291f0>,
 <tensorflow.python.keras.layers.core.Activation at 0x7f2f3c2e11f0>,
 <tensorflow.python.keras.layers.pooling.AveragePooling2D at 0x7f2f3c2e1bb0>,
 <tensorflow.python.keras.layers.core.Dropout at 0x7f2f3c2e1cd0>,
 <tensorflow.python.keras.layers.convolutional.SeparableConv2D at 0x7f2f3c2e9b80>,
 <tensorflow.python.keras.layers.normalization_v2.BatchNormalization at 0x7f2f3c2e9ee0>,
 <tensorflow.python.keras.layers.core.Activation at 0x7f2f3c2f0d90>,
 <tensorflow.python.keras.layers.pooling.AveragePooling2D at 0x7f2f3c34d820>,
 <tensorflow.python.keras.layers.core.Dropout at 0x7f2f3c2e18e0>,
 <tensorflow.python.keras.layers.core.Flatten at 0x7f2f3c2e9190>,
 <tensorflow.python.keras.layers.core.Dense at 0x7f2f3c2f8e20>,
 <tensorflow.python.keras.layers.core.Activation at 0x7f2f3c2f8c10>]

You'll see that the DepthwiseConv2D layer is at index 3 in the list (counting from 0), so you can pull its weights with

model.layers[3].get_weights()
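
If you'd rather not hard-code the index, a small sketch that finds the layer by type instead (assuming a tf.keras model, as in this repo):

import tensorflow as tf

# locate the DepthwiseConv2D layer without hard-coding its position
depthwise = next(l for l in model.layers
                 if isinstance(l, tf.keras.layers.DepthwiseConv2D))
weights = depthwise.get_weights()[0]  # [0] is the depthwise kernel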

This gets you the spatial filter weights, which we then combine with the EEG channel locations to plot a topoplot; this is what we show in Fig 6A of the paper. The spatial filters are not defined per time point; rather, a single filter spanning all time points is learned from all the training data. The number of spatial filters depends on the EEGNet configuration you train: EEGNet-8,2 learns 2 spatial filters for each of 8 temporal filters, for a total of 16 spatial filters.
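
As a minimal sketch of that plotting step, assuming MNE-Python for the topoplots and that info is an mne.Info object carrying your channel montage (both assumptions; the paper's exact plotting code may differ):

import numpy as np
import matplotlib.pyplot as plt
import mne

# DepthwiseConv2D kernel shape: (Chans, 1, F1, D) -- one length-Chans
# spatial filter per (temporal filter, depth) pair
w = model.layers[3].get_weights()[0]
spatial = np.squeeze(w, axis=1)        # (Chans, F1, D)
Chans, F1, D = spatial.shape

fig, axes = plt.subplots(F1, D, squeeze=False)
for f in range(F1):
    for d in range(D):
        # 'info' is a hypothetical mne.Info with channel locations
        mne.viz.plot_topomap(spatial[:, f, d], info,
                             axes=axes[f, d], show=False)
plt.show()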

For the second method, the convolutional kernel filter weights (Fig 7, top row) come from the first Conv2D layer, which is the temporal filter layer.

model.layers[1].get_weights()

The middle and bottom rows are the spatial filter weights, extracted using the method described above.
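
A sketch of plotting those temporal kernels (the kernel shape follows the first Conv2D layer in EEGNet; variable names are illustrative):

import numpy as np
import matplotlib.pyplot as plt

# Conv2D kernel shape: (1, kernLength, 1, F1) -- F1 one-dimensional
# temporal filters of length kernLength
w = model.layers[1].get_weights()[0]
kernels = np.squeeze(w)                # (kernLength, F1)
kernLength, F1 = kernels.shape

fig, axes = plt.subplots(F1, 1, sharex=True, squeeze=False)
for f in range(F1):
    axes[f, 0].plot(kernels[:, f])     # one learned temporal kernel per row
axes[-1, 0].set_xlabel('Samples')
plt.show()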

Figure 8 shows spatial filters from two different methods, Filter-Bank CSP (https://www.frontiersin.org/articles/10.3389/fnins.2012.00039/full) and EEGNet.

Hope this helps.