sappelhoff / pyprep

A Python implementation of the Preprocessing Pipeline (PREP) for EEG data
https://pyprep.readthedocs.io/en/latest/
MIT License

List of channel names causes TypeError in find_bad_by_ransac #116

Closed OleBialas closed 1 year ago

OleBialas commented 1 year ago

I just noticed a bug that occurs when you pass a list of strings for the complete_chn_labs argument. Here is a reproducible example using sample data from MNE:

import os
import numpy as np
import mne
from pyprep.ransac import find_bad_by_ransac

sample_data_folder = mne.datasets.sample.data_path()
sample_data_raw_file = os.path.join(sample_data_folder, 'MEG', 'sample',
                                    'sample_audvis_raw.fif')
raw = mne.io.read_raw_fif(sample_data_raw_file)
# keep only the EEG channels to save memory
data = raw.pick_types(meg=False, eeg=True).get_data()
samplerate = raw.info['sfreq']
complete_chn_labs = raw.info['ch_names']
chn_pos = np.stack([ch['loc'][0:3] for ch in raw.info['chs']])
bads, corr = find_bad_by_ransac(data, samplerate, complete_chn_labs, chn_pos, exclude=[])

Raises:

TypeError                                 Traceback (most recent call last)
Input In [1], in <cell line: 15>()
     13 complete_chn_labs = raw.info['ch_names']
     14 chn_pos = np.stack([ch['loc'][0:3] for ch in raw.info['chs']])
---> 15 bads, corr = find_bad_by_ransac(data, samplerate, complete_chn_labs, chn_pos, exclude=[])

File ~/miniconda3/envs/eelbrain-cnsp/lib/python3.9/site-packages/pyprep/ransac.py:251, in find_bad_by_ransac(data, sample_rate, complete_chn_labs, chn_pos, exclude, n_samples, sample_prop, corr_thresh, frac_bad, corr_window_secs, channel_wise, max_chunk_size, random_state, matlab_strict)
    249 # find the corresponding channel names and return
    250 bad_ransac_channels_idx = np.argwhere(frac_bad_corr_windows > frac_bad)
--> 251 bad_ransac_channels_name = complete_chn_labs[bad_ransac_channels_idx.astype(int)]
    252 bad_by_ransac = [i[0] for i in bad_ransac_channels_name]
    253 print("\nRANSAC done!")

TypeError: only integer scalar arrays can be converted to a scalar index 

Converting the channel labels to an array fixes the issue:

bads, corr = find_bad_by_ransac(data, samplerate, np.array(complete_chn_labs), chn_pos, exclude=[])
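The cause can be shown without MNE or pyprep. The sketch below is only an illustration with made-up channel names and made-up values: it mimics the indexing step from the traceback, where np.argwhere returns a 2-D integer array that a plain Python list cannot be indexed with, while a NumPy array can.

```python
import numpy as np

labels = ["EEG 001", "EEG 002", "EEG 003"]  # made-up channel names
frac_bad_corr_windows = np.array([0.1, 0.9, 0.7])  # made-up values

# np.argwhere returns a 2-D array of shape (n_hits, 1)
idx = np.argwhere(frac_bad_corr_windows > 0.4)

try:
    labels[idx.astype(int)]  # a plain list cannot be indexed with an array
    list_error = None
except TypeError as err:
    list_error = str(err)

# An ndarray supports integer-array indexing; the result has shape (n_hits, 1)
picked = np.array(labels)[idx.astype(int)]
bad = [i[0] for i in picked]

print(list_error)
print(bad)
```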

So I guess you could change your code to properly handle lists as input (which would be useful, given that MNE stores channel labels as lists) or simply raise a more informative error message.
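Either option could be sketched roughly as follows. Note that as_label_array is a hypothetical helper name, not something that exists in pyprep, and this is only one possible way to do it: it accepts any sequence of strings, converts it to a NumPy array, and raises a descriptive TypeError otherwise.

```python
import numpy as np


def as_label_array(complete_chn_labs):
    """Hypothetical helper: coerce channel labels to a NumPy array of str."""
    # A bare string is iterable too, so reject it explicitly
    if isinstance(complete_chn_labs, str) or not all(
        isinstance(label, str) for label in complete_chn_labs
    ):
        raise TypeError(
            "complete_chn_labs must be a list or array of channel names (str)."
        )
    return np.asarray(complete_chn_labs, dtype=str)


# Lists and arrays of names are both accepted and can then be
# indexed with the 2-D integer array that np.argwhere returns
labels = as_label_array(["Fz", "Cz", "Pz"])
print(labels[np.array([[0], [2]])])
```

Calling such a helper once at the top of the function would leave the later indexing code unchanged.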

sappelhoff commented 1 year ago

Thanks for the report! Would you like to submit a pull request to solve this issue?

OleBialas commented 1 year ago

Sure, I can do that at the start of next week.