vferat / pycrostates

https://pycrostates.readthedocs.io/
BSD 3-Clause "New" or "Revised" License
36 stars 11 forks source link

[ENH] Add a spatial filter #6

Closed vferat closed 1 year ago

vferat commented 2 years ago

Add a preprocessing function to spatially smooth data (Raw, Epochs, and possibly Evoked), similar to the spatial filter option in Cartool. [image]

vferat commented 2 years ago

import mne

# Neighbourhood structure of the EEG sensors: ``adjacency`` is a sparse
# (n_channels, n_channels) matrix and ``ch_names`` lists the channels it covers,
# in the same order. Unpack both so they can be reused by the spatial filter.
adjacency, ch_names = mne.channels.find_ch_adjacency(raw.info, ch_type='eeg')

vferat commented 1 year ago

Something like this should do the job:

import numpy as np
import mne
from mne.channels.interpolation import _make_interpolation_matrix

def _spatial_filter(data, adjacency, weights_matrix):
    """Spatially smooth ``data`` with a trimmed, weighted neighbour average.

    For every channel and every time sample, the values of the channel's
    neighbours (rows of ``adjacency``) are collected. The largest and smallest
    of them are discarded; when several neighbours tie for an extreme, all
    tied values are discarded. The remaining values are then averaged with
    weights ``weights_matrix[neighbour, channel]``.

    Parameters
    ----------
    data : array, shape (n_channels, n_times)
        Signal to filter.
    adjacency : array of bool, shape (n_channels, n_channels)
        ``adjacency[c]`` marks the channels used to smooth channel ``c``.
    weights_matrix : array, shape (n_channels, n_channels)
        ``weights_matrix[n, c]`` is the weight of neighbour ``n`` when
        smoothing channel ``c``.

    Returns
    -------
    new_data : array, shape (n_channels, n_times)
        Filtered signal. A sample keeps its original value when no neighbour
        survives the trimming, or when the surviving weights sum to zero
        (where a weighted average is undefined).
    """
    data = np.asarray(data, dtype=float)
    adjacency = np.asarray(adjacency, dtype=bool)
    weights_matrix = np.asarray(weights_matrix, dtype=float)
    new_data = data.copy()
    for ch, neighbours in enumerate(adjacency):
        idx = np.flatnonzero(neighbours)
        if idx.size == 0:
            continue
        nb_data = data[idx]  # (n_neighbours, n_times)
        # Reject local outliers: drop the max and min neighbour at each sample.
        is_extreme = (nb_data == nb_data.max(axis=0, keepdims=True)) | (
            nb_data == nb_data.min(axis=0, keepdims=True)
        )
        # Per-sample weights; trimmed neighbours contribute nothing.
        weights = np.where(is_extreme, 0.0, weights_matrix[idx, ch][:, np.newaxis])
        weight_sum = weights.sum(axis=0)
        # Weighted mean sum(w * x) / sum(w), only where it is defined.
        # NOTE(review): weights of mixed sign can give a sum close to (but not
        # exactly) zero and hence very large values -- confirm weights are >= 0.
        valid = weight_sum != 0
        new_data[ch, valid] = (weights * nb_data).sum(axis=0)[valid] / weight_sum[valid]
    return new_data


picks = 'eeg'
inst = raw_EO
data = inst.get_data()

M, ch_names = mne.channels.find_ch_adjacency(inst.info, picks)
M = M.toarray()
# Rows of ``data`` matching the adjacency matrix, so recordings that also
# contain non-EEG channels stay aligned.
ch_idx = [inst.ch_names.index(name) for name in ch_names]

pos = inst._get_channel_positions(ch_names)
# NOTE(review): this interpolates the sensor positions onto themselves; a
# spline interpolator evaluated at its own nodes is expected to be close to the
# identity, so the neighbour (off-diagonal) weights may be tiny and of mixed
# sign -- verify these weights behave as intended (distance-based weights are
# an alternative).
interpolate_matrix = _make_interpolation_matrix(pos, pos)

# Channels not covered by ``picks`` are passed through unchanged.
new_data = data.copy()
new_data[ch_idx] = _spatial_filter(data[ch_idx], M, interpolate_matrix)
new_raw = mne.io.RawArray(info=inst.info, data=new_data)

image

Generated from a random selection of time samples of a raw EEG recording (top row: original, bottom row: filtered):

import matplotlib.pyplot as plt 
# Visual check: topographies at ``n`` random time samples, before (top row)
# and after (bottom row) spatial filtering.
# NOTE(review): assumes every channel in ``inst.info`` is an EEG channel with a
# position -- confirm for recordings that contain other channel types.
n = 10
orig_data = inst.get_data()  # fetched once instead of once per plotted sample
filt_data = new_raw.get_data()
random_sample = np.random.randint(0, inst.n_times, n)
sphere = np.array([0, 0, 0, 1])
# squeeze=False keeps ``axes`` 2-D even when n == 1.
fig, axes = plt.subplots(nrows=2, ncols=n, figsize=(10, 2), squeeze=False)
for s, sample in enumerate(random_sample):
    mne.viz.plot_topomap(orig_data[:, sample], pos=inst.info, axes=axes[0, s],
                         sphere=sphere, show=False)
    mne.viz.plot_topomap(filt_data[:, sample], pos=new_raw.info, axes=axes[1, s],
                         sphere=sphere, show=False)

# Per-channel correlation between original and filtered signals: the
# off-diagonal block of the stacked correlation matrix pairs filtered channel i
# with original channel i. The channel count is read from the data rather
# than hard-coded.
n_ch = orig_data.shape[0]
channel_corr = np.diag(np.corrcoef(orig_data, filt_data)[n_ch:, :n_ch])
print(channel_corr)