eeyhsong / EEG-Conformer

EEG Transformer 2.0. i. Convolutional Transformer for EEG Decoding. ii. Novel visualization - Class Activation Topography.
GNU General Public License v3.0

Regarding target category in visualization scripts #17

Open rmib200 opened 1 year ago

rmib200 commented 1 year ago

It seems that the visualization script is missing a way to select which target category to visualize. In the CAT.py script, a variable called `target_category` exists but is not used anywhere. While reviewing the utils.py script, I found some references to a similarly named variable, but it is set to `None`. Is the script missing an argument to the `GradCAM` class through which `target_category` can be defined? I would really appreciate your help in defining which categories to plot, since I am trying to reproduce the results.
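Why the target category matters: Grad-CAM weights each feature map by the gradient of the *chosen* class score, so a different class gives a different map. The sketch below is a minimal NumPy version on a toy model whose class scores are linear in the feature maps. For such a model, the gradient of a class score is simply that class's weight tensor. `grad_cam_linear_head` and the toy arrays are hypothetical; this is not the repository's `GradCAM` class. The `GradCAM(...)` call in CAT.py appears to follow the older pytorch-grad-cam interface, where `target_category=None` explains the highest-scoring class. Whether the repo's utils.py behaves the same way is an assumption. Library implementations also rescale and resize the map, which is omitted here.

```python
import numpy as np

def grad_cam_linear_head(activations, weights, target_category=None):
    """Grad-CAM for a toy model whose class scores are linear in the feature maps.

    activations: (K, H, W) feature maps A^k for a single input.
    weights: (C, K, H, W) hypothetical linear head, score_c = sum(weights[c] * activations).
    target_category: class to explain; None falls back to the top-scoring class.
    """
    scores = np.tensordot(weights, activations, axes=3)  # (C,) class scores
    if target_category is None:
        target_category = int(np.argmax(scores))
    # For a linear head, d(score_c)/d(activations) is exactly weights[c].
    grads = weights[target_category]                      # (K, H, W)
    alpha = grads.mean(axis=(1, 2))                       # (K,) global-average-pooled gradients
    cam = np.maximum(np.tensordot(alpha, activations, axes=1), 0.0)  # ReLU(sum_k alpha_k * A^k)
    return target_category, cam

# Two feature maps over three time tokens (H = 1, as with reshape_transform and h=1).
A = np.array([[[1.0, 0.0, 0.0]],
              [[0.0, 0.0, 2.0]]])
Wt = np.zeros((2, 2, 1, 3))
Wt[0, 0] = 1.0  # class 0 reads feature map 0
Wt[1, 1] = 1.0  # class 1 reads feature map 1

for c in (0, 1, None):
    print(c, grad_cam_linear_head(A, Wt, target_category=c))
```

Explaining class 0 highlights the first token, while explaining class 1 highlights the last one. With `None`, the class with the larger score (class 1) is explained.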

eeyhsong commented 1 year ago

Hello @rmib200! I calculated the CAM for each EEG trial and then averaged the CAMs of the trials belonging to one category. 🤝
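The per-category averaging described above can be sketched with a boolean mask over the trial axis. This minimal sketch assumes two things: the per-trial CAMs are stacked to shape (n_trials, n_channels, n_times), and `labels` holds the ground-truth class of each trial in the same order. `class_mean_cam` is a hypothetical helper, not part of the repository.

```python
import numpy as np

def class_mean_cam(all_cam, labels, category):
    """Average the per-trial CAMs of the trials whose ground-truth label is `category`."""
    all_cam = np.asarray(all_cam)   # (n_trials, n_channels, n_times)
    labels = np.asarray(labels)     # (n_trials,)
    if len(all_cam) != len(labels):
        raise ValueError("expected exactly one label per CAM")
    mask = labels == category
    if not mask.any():
        raise ValueError(f"no trials with label {category}")
    return all_cam[mask].mean(axis=0)  # (n_channels, n_times)

cams = np.arange(4 * 2 * 3, dtype=float).reshape(4, 2, 3)  # 4 toy trials, 2 channels, 3 samples
labels = [0, 2, 2, 1]
print(class_mean_cam(cams, labels, category=2))
```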

rmib200 commented 1 year ago

I've been trying to solve this for a couple of days now. Could you please point me to where you choose the trials of one category? I believe I may be missing something, but I can't figure out what. If you could be so kind, here is an extract of the CAT.py script where the CAM is calculated:

```python
import numpy as np
import torch
import matplotlib.pyplot as plt
import mne
from einops import rearrange
# GradCAM and ViT come from the repository's own scripts; their imports are omitted in this extract.

data = np.load('./grad_cam/train_data.npy')
print(np.shape(data))

nSub = 1
# "set the class (class activation mapping)": I suppose this is the target category,
# but it is not used anywhere in the code
target_category = 2


def reshape_transform(tensor):
    # (batch, tokens, emb) -> (batch, emb, 1, tokens) so Grad-CAM sees a 2-D map
    return rearrange(tensor, 'b (h w) e -> b e (h) (w)', h=1)


device = torch.device("cpu")
model = ViT()
model.load_state_dict(torch.load('./model/sub%d.pth' % nSub, map_location=device))
target_layers = [model[1]]  # set the target layer
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False, reshape_transform=reshape_transform)

# TODO: Class Activation Topography (proposed in the paper)
biosemi_montage = mne.channels.make_standard_montage('biosemi64')
index = [37, 9, 10, 46, 45, 44, 13, 12, 11, 47, 48, 49, 50, 17, 18, 31, 55, 54, 19, 30, 56, 29]  # for BCI Competition IV 2a
biosemi_montage.ch_names = [biosemi_montage.ch_names[i] for i in index]
biosemi_montage.dig = [biosemi_montage.dig[i + 3] for i in index]
info = mne.create_info(ch_names=biosemi_montage.ch_names, sfreq=250., ch_types='eeg')

all_cam = []
# this loop obtains the CAM of each trial/sample
for i in range(len(data)):
    test = torch.as_tensor(data[i:i + 1], dtype=torch.float32).requires_grad_(True)
    grayscale_cam = cam(input_tensor=test)  # no target category is passed here
    all_cam.append(grayscale_cam[0, :])

# the mean of all data
test_all_data = np.squeeze(np.mean(data, axis=0))
mean_all_test = np.mean(test_all_data, axis=1)

# the mean of all CAMs
test_all_cam = np.mean(all_cam, axis=0)
mean_all_cam = np.mean(test_all_cam, axis=1)

# apply the CAM to the input data
hyb_all = test_all_data * test_all_cam
mean_hyb_all = np.mean(hyb_all, axis=1)

evoked = mne.EvokedArray(test_all_data, info)
evoked.set_montage(biosemi_montage)

fig, [ax1, ax2] = plt.subplots(nrows=2)
im1, cn1 = mne.viz.plot_topomap(mean_all_test, evoked.info, show=False, axes=ax1, res=1200)
im2, cn2 = mne.viz.plot_topomap(mean_hyb_all, evoked.info, show=False, axes=ax2, res=1200)
```
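The topography step of the script reduces to a few NumPy operations. This is a minimal sketch under two assumptions: `data` has shape (n_trials, 1, n_channels, n_times), and each CAM has shape (n_channels, n_times), so that the element-wise product in the script is well defined. The two returned vectors play the roles of `mean_all_test` and `mean_hyb_all` and could be passed to `mne.viz.plot_topomap`. `class_activation_topography` is a hypothetical helper, and random arrays stand in for real trials.

```python
import numpy as np

def class_activation_topography(data, cams):
    """Per-channel topography values mirroring the averaging steps in CAT.py.

    data: (n_trials, 1, n_channels, n_times) EEG trials (assumed layout).
    cams: (n_trials, n_channels, n_times) per-trial CAMs, e.g. np.stack(all_cam).
    Returns (raw, weighted): channel means of the averaged data, and of the averaged
    data multiplied sample-by-sample with the averaged CAM.
    """
    mean_data = np.asarray(data).mean(axis=0).squeeze(axis=0)  # (n_channels, n_times)
    mean_cam = np.asarray(cams).mean(axis=0)                   # (n_channels, n_times)
    weighted = mean_data * mean_cam                            # CAM-weighted EEG
    return mean_data.mean(axis=1), weighted.mean(axis=1)

rng = np.random.default_rng(0)
data = rng.standard_normal((10, 1, 22, 50))   # 10 toy trials, 22 channels, 50 samples
cams = rng.random((10, 22, 50))               # toy CAM values in [0, 1)
raw, weighted = class_activation_topography(data, cams)
print(raw.shape, weighted.shape)
```

A CAM that is 1 everywhere leaves the data unchanged, so both vectors coincide in that case. Lower CAM values shrink a sample's contribution to its channel's value.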

eeyhsong commented 1 year ago

@rmib200 Hello, so sorry for the late reply. `all_cam` is obtained in the middle of the code. From it, you can get the CAM results corresponding to each category by using the ground-truth labels. Does that help?
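Selecting by ground-truth label, as described above, can be applied to every category in one pass. This minimal sketch assumes `labels` lists the true class of each trial in the same order as `data` and `all_cam`. `topographies_by_category` is a hypothetical helper. One design choice here differs from the script quoted earlier, which averages the data over all trials: in this sketch, the data average is also restricted to the category's trials, so the EEG mean and the CAM mean describe the same set of trials.

```python
import numpy as np

def topographies_by_category(data, cams, labels):
    """CAM-weighted channel values for every ground-truth category.

    data: (n_trials, 1, n_channels, n_times); cams: (n_trials, n_channels, n_times);
    labels: (n_trials,) ground-truth class of each trial, in the same order as data.
    Both averages use only the trials of the category (a choice made in this sketch).
    """
    data, cams, labels = np.asarray(data), np.asarray(cams), np.asarray(labels)
    if not (len(data) == len(cams) == len(labels)):
        raise ValueError("data, cams and labels must describe the same trials")
    result = {}
    for category in np.unique(labels):
        mask = labels == category
        mean_data = data[mask].mean(axis=0).squeeze(axis=0)  # (n_channels, n_times)
        mean_cam = cams[mask].mean(axis=0)                    # (n_channels, n_times)
        result[int(category)] = (mean_data * mean_cam).mean(axis=1)  # (n_channels,)
    return result

rng = np.random.default_rng(1)
data = rng.standard_normal((8, 1, 22, 50))
cams = rng.random((8, 22, 50))
labels = np.array([0, 1, 2, 3, 0, 1, 2, 3])   # four motor-imagery classes, two toy trials each
topo = topographies_by_category(data, cams, labels)
print(sorted(topo), topo[2].shape)
```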

rmib200 commented 1 year ago

Do you mean filtering `all_cam` to keep only the `target_category` trials? I've tried something like this: `filtered_list = [x for x, y in zip(all_cam, y_labels) if y == target_category]`

and then computing the mean over all the selected CAMs. The final code looks something like this:

    # Keep only the CAMs whose ground-truth label equals target_category
    filtered_all_cam = [x for x, y in zip(all_cam, y_test_encoded) if y == target_category]
    # Calculate mean across all trials/samples
    test_all_data = np.squeeze(np.mean(data, axis=0))
    mean_all_test = np.mean(test_all_data, axis=1)

    # Calculate mean across all CAMs
    # test_all_cam = np.mean(all_cam, axis=0)
    test_all_cam = np.mean(filtered_all_cam, axis=0)
    # print("test_all_data", test_all_data.shape)
    mean_all_cam = np.mean(test_all_cam, axis=1)

    # Apply CAM on the input data
    hyb_all = test_all_data * test_all_cam
    mean_hyb_all = np.mean(hyb_all, axis=1)

Is this approach correct?