vlawhern / arl-eegmodels

This is the Army Research Laboratory (ARL) EEGModels Project: a collection of Convolutional Neural Network (CNN) models for EEG signal classification, using Keras and TensorFlow.

DeepExplain issue #29

Closed · Alan9890 closed this issue 3 years ago

Alan9890 commented 3 years ago

Does anyone know how to solve this?

[Screenshot "Captura": error traceback attached as an image]

vlawhern commented 3 years ago

This is due to the transition to TensorFlow 2: the original DeepExplain package does not support TF2 out of the box. There is an open pull request (https://github.com/marcoancona/DeepExplain/pull/55) that adds TF2 support, as long as you disable eager execution:

import tensorflow as tf
tf.compat.v1.disable_eager_execution()

...
(the rest of your code)
...

Here's a code snippet that works out-of-the-box with the above pull request (using the MNE sample dataset):


# import tensorflow and disable eager execution right up front
import tensorflow as tf
tf.compat.v1.disable_eager_execution()

import numpy as np

# mne imports
import mne
from mne import io
from mne.datasets import sample

# EEGNet-specific imports
from EEGModels import EEGNet
from tensorflow.keras import utils as np_utils
from tensorflow.keras.callbacks import ModelCheckpoint
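# note: get_session() was removed from the public tf.keras backend in TF2,
# so we import the private tensorflow.python.keras backend instead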
from tensorflow.python.keras import backend as K
from tensorflow.keras.models import Model

from deepexplain.tensorflow import DeepExplain

# while the default tensorflow ordering is 'channels_last', we set it here
# to be explicit in case the user has changed the default ordering
K.set_image_data_format('channels_last')

##################### Process, filter and epoch the data ######################
data_path = sample.data_path()

# Set parameters and read data
raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
event_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif'
tmin, tmax = -0., 1
event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4)

# Setup for reading the raw data
raw = io.Raw(raw_fname, preload=True, verbose=False)
raw.filter(2, None, method='iir')  # replace baselining with high-pass
events = mne.read_events(event_fname)

raw.info['bads'] = ['MEG 2443']  # set bad channels
picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                       exclude='bads')

# Read epochs
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=False,
                    picks=picks, baseline=None, preload=True, verbose=False)
labels = epochs.events[:, -1]

# extract raw data and scale by 1000 due to scaling sensitivity in deep
# learning (MNE returns EEG in volts, so this converts to millivolts)
X = epochs.get_data()*1000 # format is in (trials, channels, samples)
y = labels

kernels, chans, samples = 1, 60, 151

# take 50/25/25 percent of the data to train/validate/test
X_train      = X[0:144,]
Y_train      = y[0:144]
X_validate   = X[144:216,]
Y_validate   = y[144:216]
X_test       = X[216:,]
Y_test       = y[216:]

# convert labels to one-hot encodings (the event ids are 1..4, so subtract
# 1 to get zero-indexed class labels)
Y_train      = np_utils.to_categorical(Y_train-1)
Y_validate   = np_utils.to_categorical(Y_validate-1)
Y_test       = np_utils.to_categorical(Y_test-1)

# convert data to NHWC (trials, channels, samples, kernels) format. Data 
# contains 60 channels and 151 time-points. Set the number of kernels to 1.
X_train      = X_train.reshape(X_train.shape[0], chans, samples, kernels)
X_validate   = X_validate.reshape(X_validate.shape[0], chans, samples, kernels)
X_test       = X_test.reshape(X_test.shape[0], chans, samples, kernels)

print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

# configure the EEGNet-8,2,16 model with kernel length of 32 samples (other 
# model configurations may do better, but this is a good starting point)
model = EEGNet(nb_classes = 4, Chans = chans, Samples = samples, 
               dropoutRate = 0.5, kernLength = 32, F1 = 8, D = 2, F2 = 16, 
               dropoutType = 'Dropout')
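# (the "8,2,16" refers to F1 = 8 temporal filters, D = 2 spatial filters
# learned per temporal filter, and F2 = 16 pointwise filters, following the
# naming convention in the EEGNet paper)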

# compile the model and set the optimizers
model.compile(loss='categorical_crossentropy', optimizer='adam', 
              metrics = ['accuracy'])

# count number of parameters in the model
numParams    = model.count_params()    

# set a valid path for your system to record model checkpoints
checkpointer = ModelCheckpoint(filepath='/tmp/checkpoint.h5', verbose=1,
                               save_best_only=True)
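# (with save_best_only=True the file keeps the weights with the best
# validation loss; they can be restored later with
# model.load_weights('/tmp/checkpoint.h5') if desired)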

###############################################################################
# if the classification task was imbalanced (significantly more trials in one
# class versus the others) you can assign a weight to each class during 
# optimization to balance it out. This data is approximately balanced so we 
# don't need to do this, but is shown here for illustration/completeness. 
###############################################################################

# the syntax is {class_1:weight_1, class_2:weight_2,...}. Here just setting
# the weights all to be 1
class_weights = {0:1, 1:1, 2:1, 3:1}

fittedModel = model.fit(X_train, Y_train, batch_size = 16, epochs = 5, 
                        verbose = 2, validation_data=(X_validate, Y_validate),
                        callbacks=[checkpointer], class_weight = class_weights)

with DeepExplain(session = K.get_session()) as de:
    input_tensor   = model.layers[0].input
    fModel         = Model(inputs = input_tensor, outputs = model.layers[-2].output)    
    target_tensor  = fModel(input_tensor)    
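    # model.layers[-2] is the dense layer just before the softmax activation,
    # so attributions are computed on the pre-softmax scores. Multiplying
    # target_tensor by the one-hot Y_test (below) masks the outputs so that
    # each trial is explained with respect to its true class only.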

    # can use epsilon-LRP as well if you like.
    attributions   = de.explain('deeplift', target_tensor * Y_test, input_tensor, X_test)
    # attributions = de.explain('elrp', target_tensor * Y_test, input_tensor, X_test)   
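    # 'attributions' has the same shape as X_test (trials, chans, samples, 1),
    # i.e. one relevance value per channel and time point for each test trial
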
vlawhern commented 3 years ago

Alternatively, you could fix this manually by editing /deepexplain/tensorflow/methods.py directly, although this is a pretty bad hack:

  1. Replace tf.placeholder with tf.compat.v1.placeholder
  2. Replace tf.get_default_graph with tf.compat.v1.get_default_graph
  3. Replace tf.get_default_session with tf.compat.v1.get_default_session

I've verified this also works (though I haven't tested it extensively); the PR above is the better route. An equivalent runtime monkey-patch is sketched below.
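
If you'd rather not edit the installed package, the same three substitutions can be applied as a runtime monkey-patch. This is only a sketch under the same caveats: it assumes deepexplain looks the symbols up on the shared tensorflow module object, and it must run before deepexplain is imported.

# sketch: alias the removed TF1 symbols back onto the top-level tf module
# so deepexplain/tensorflow/methods.py finds them; run before the import
import tensorflow as tf
tf.compat.v1.disable_eager_execution()

tf.placeholder         = tf.compat.v1.placeholder
tf.get_default_graph   = tf.compat.v1.get_default_graph
tf.get_default_session = tf.compat.v1.get_default_session

from deepexplain.tensorflow import DeepExplain  # should now import cleanly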

Alan9890 commented 3 years ago

Very good, that works. Thank you!