sappelhoff / pyprep

PyPREP: A Python implementation of the Preprocessing Pipeline (PREP) for EEG data
https://pyprep.readthedocs.io/en/latest/
MIT License

find_all_bads throws inf/nan error after filtering/detrending? #129

Closed: sunshineinsandiego closed this issue 1 year ago

sunshineinsandiego commented 1 year ago

Hi, I am trying to find and remove all bad channels from a 19-channel EEG array. When I call find_all_bads, I receive the following error:

>>> import mne
>>> import numpy as np
>>> from find_noisy_channels import NoisyChannels

>>> data.shape 
(19, 625773)
>>> np.isnan(data).any()
False
>>> np.isinf(data).any()
False
>>> np.isfinite(data).all()
True

>>> mne_info = mne.create_info(ch_names = channels, sfreq = eeg_stream_srate, ch_types = 'eeg')
>>> mne_data = mne.io.RawArray(data, mne_info)

>>> noisy_channels = NoisyChannels(mne_data, random_state = 1337)
>>> noisy_channels.find_all_bads(ransac = True, channel_wise = True)

Setting up high-pass filter at 1 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal highpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Filter length: 1651 samples (3.302 sec)

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
...

File .../find_noisy_channels.py:222, in NoisyChannels.find_all_bads(self, ransac, channel_wise, max_chunk_size)
    220 self.find_bad_by_SNR()
    221 if ransac:
--> 222     self.find_bad_by_ransac(
    223         channel_wise=channel_wise, max_chunk_size=max_chunk_size
    224     )

File .../find_noisy_channels.py:577, in NoisyChannels.find_bad_by_ransac(self, n_samples, sample_prop, corr_thresh, frac_bad, corr_window_secs, channel_wise, max_chunk_size)
    573 exclude_from_ransac = (
    574     self.bad_by_correlation + self.bad_by_deviation + self.bad_by_dropout
    575 )
    576 rng = copy(self.random_state) if self.matlab_strict else self.random_state
--> 577 self.bad_by_ransac, ch_correlations_usable = find_bad_by_ransac(
    578     self.EEGFiltered,
    579     self.sample_rate,
    580     self.ch_names_new,
    581     self.raw_mne._get_channel_positions()[self.usable_idx, :],
...
--> 603     raise ValueError(
    604         "array must not contain infs or NaNs")
    605 return a

ValueError: array must not contain infs or NaNs

Because the input data array contains no NaNs or infs, my assumption is that some step of the detrending/filtering inside find_all_bads() creates inf or NaN values that cannot be processed further.

But when I detrend/filter manually, I still get no NaNs or infs:

>>> from removeTrend import removeTrend
>>> data_detrend = removeTrend(mne_data.get_data(), sample_rate = eeg_stream_srate, matlab_strict = False)
>>> np.isfinite(data_detrend).all()
True

If I examine the full stack trace, it looks like the error is raised in the call to _make_interpolation_matrices in ransac.py, which relies on MNE to call SciPy's pseudo-inverse function (linalg.pinv)?

  Input In [30] in <cell line: 3>
    noisy_channels.find_all_bads(ransac = True, channel_wise = True)

  File .../find_noisy_channels.py:222 in find_all_bads
    self.find_bad_by_ransac(

  File .../find_noisy_channels.py:577 in find_bad_by_ransac
    self.bad_by_ransac, ch_correlations_usable = find_bad_by_ransac(

  File .../ransac.py:162 in find_bad_by_ransac
    interp_mats = _make_interpolation_matrices(random_ch_picks, chn_pos_good)

  File .../ransac.py:284 in _make_interpolation_matrices
    mat[:, sample] = _make_interpolation_matrix(subset_pos, chn_pos_good)

  File ~/miniconda3/envs/.../lib/python3.9/site-packages/mne/channels/interpolation.py:108 in _make_interpolation_matrix
    C_inv = linalg.pinv(C)

  File ~/miniconda3/envs/.../lib/python3.9/site-packages/scipy/linalg/basic.py:1315 in pinv
    a = _asarray_validated(a, check_finite=check_finite)

  File ~/miniconda3/envs/.../lib/python3.9/site-packages/scipy/_lib/_util.py:293 in _asarray_validated
    a = toarray(a)

  File ~/miniconda3/envs/.../lib/python3.9/site-packages/numpy/lib/function_base.py:603 in asarray_chkfinite
    raise ValueError(

ValueError: array must not contain infs or NaNs
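The last frames show that pinv validates its input through numpy.asarray_chkfinite, and that the matrix it receives is assembled from channel positions (_make_interpolation_matrix(subset_pos, chn_pos_good)), not from the EEG samples. So the isfinite checks on data above cannot rule this input out. A minimal sketch, using hypothetical positions and a plain dot-product matrix that only illustrates NaN propagation (it is not MNE's actual interpolation formula):

```python
import numpy as np

# Hypothetical positions for 3 channels; the last row is NaN, standing in
# for a channel whose location was never set.
pos = np.array([
    [0.0, 1.0, 0.0],
    [1.0, 0.0, 0.0],
    [np.nan, np.nan, np.nan],
])

# Pairwise dot products between positions: one NaN row contaminates the
# whole corresponding row and column of any matrix built from positions.
cosang = pos @ pos.T


def passes_finite_check(mat):
    """Mimic the validation scipy.linalg.pinv performs with check_finite=True."""
    try:
        np.asarray_chkfinite(mat)
    except ValueError:  # "array must not contain infs or NaNs"
        return False
    return True


print(passes_finite_check(cosang))               # False: row/column 3 are NaN
print(passes_finite_check(pos[:2] @ pos[:2].T))  # True: only positioned channels
```

Perfectly finite EEG samples can therefore coexist with this error whenever the positions themselves contain NaNs.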

Where are the NaNs/infs coming from, and do you have any suggestions for how to address this?

Thanks!

sappelhoff commented 1 year ago

Thanks for the report! Have you tried this with other data that is comparable in time length and channel number as well and ran into the same issue?

I could imagine that it has something to do with the small number of channels, but I am not sure, and even if that is the case, there should be a proper error message.

sappelhoff commented 1 year ago

One thing to check is the exclude_from_ransac variable: how many of the 19 channels have already been identified as bad before the RANSAC routine even starts?

https://github.com/sappelhoff/pyprep/blob/c6b268b90e08af2d4740496414cad649344345fa/pyprep/find_noisy_channels.py#L573-L592

sunshineinsandiego commented 1 year ago

Thanks. If I add print(exclude_from_ransac) there, I get the following output: ['Fp1', 'Fp2', 'F8', 'Cz', 'T3', 'Cz'], so it looks like 6 entries (5 distinct channels, since Cz is listed twice) are excluded before the RANSAC routine.
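Since exclude_from_ransac is built by concatenating bad_by_correlation, bad_by_deviation and bad_by_dropout (see the linked lines), a channel flagged by two detectors appears twice, and the list length overstates how many channels are dropped. A small sketch for counting what is actually left for RANSAC; ransac_pool is a hypothetical helper, not part of PyPREP, and the channel list is copied from the minimal example later in this thread:

```python
# Channel list from the minimal example later in this thread.
CHANNELS = ['F7', 'Fp1', 'Fp2', 'F8', 'F3', 'Fz', 'F4', 'C3', 'Cz', 'P8',
            'P7', 'Pz', 'P4', 'T3', 'P3', 'O1', 'O2', 'C4', 'T4']


def ransac_pool(all_channels, exclude_from_ransac):
    """Split channels into (distinct excluded, remaining) for RANSAC.

    Duplicates in the exclusion list are collapsed, because the same
    channel can be flagged by more than one detector.
    """
    excluded = set(exclude_from_ransac)
    remaining = [ch for ch in all_channels if ch not in excluded]
    return sorted(excluded), remaining


excluded, remaining = ransac_pool(CHANNELS, ['Fp1', 'Fp2', 'F8', 'Cz', 'T3', 'Cz'])
print(len(excluded), len(remaining))  # 5 14
```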

If I try a different EEG dataset, a different set of channels is excluded from RANSAC (again 6 entries, 5 distinct since Fp1 is listed twice), and the same error occurs:

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal highpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
...
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Filter length: 1651 samples (3.302 sec)

['F7', 'Fp1', 'Fp2', 'T3', 'T4', 'Fp1']

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [6], in <cell line: 4>()
      2 from find_noisy_channels import NoisyChannels
      3 noisy_channels = NoisyChannels(mne_data, random_state = 1337)
----> 4 noisy_channels.find_all_bads(ransac = True, channel_wise = True)

File /mnt/work/.../find_noisy_channels.py:222, in NoisyChannels.find_all_bads(self, ransac, channel_wise, max_chunk_size)
    220 self.find_bad_by_SNR()
    221 if ransac:
--> 222     self.find_bad_by_ransac(
    223         channel_wise=channel_wise, max_chunk_size=max_chunk_size
    224     )

File /mnt/work/.../find_noisy_channels.py:578, in NoisyChannels.find_bad_by_ransac(self, n_samples, sample_prop, corr_thresh, frac_bad, corr_window_secs, channel_wise, max_chunk_size)
    576 print(exclude_from_ransac)
    577 rng = copy(self.random_state) if self.matlab_strict else self.random_state
--> 578 self.bad_by_ransac, ch_correlations_usable = find_bad_by_ransac(
    579     self.EEGFiltered,
    580     self.sample_rate,
    581     self.ch_names_new,
    582     self.raw_mne._get_channel_positions()[self.usable_idx, :],
    583     exclude_from_ransac,
    584     n_samples,
    585     sample_prop,
    586     corr_thresh,
    587     frac_bad,
    588     corr_window_secs,
    589     channel_wise,
    590     max_chunk_size,
    591     rng,
    592     self.matlab_strict,
    593 )
    595 # Reshape correlation matrix to match original channel count
    596 n_ransac_windows = ch_correlations_usable.shape[0]

File /mnt/work/.../ransac.py:162, in find_bad_by_ransac(data, sample_rate, complete_chn_labs, chn_pos, exclude, n_samples, sample_prop, corr_thresh, frac_bad, corr_window_secs, channel_wise, max_chunk_size, random_state, matlab_strict)
    159     random_ch_picks.append(picks)
    161 # Generate interpolation matrix for each RANSAC sample
--> 162 interp_mats = _make_interpolation_matrices(random_ch_picks, chn_pos_good)
    164 # Calculate the size (in frames) and count of correlation windows
    165 correlation_frames = corr_window_secs * sample_rate

File /mnt/work/.../ransac.py:284, in _make_interpolation_matrices(random_ch_picks, chn_pos_good)
    282     mat = np.zeros((n_chans_good, n_chans_good))
    283     subset_pos = chn_pos_good[sample, :]
--> 284     mat[:, sample] = _make_interpolation_matrix(subset_pos, chn_pos_good)
    285     interpolation_mats.append(mat)
    287 return interpolation_mats

File ~/miniconda3/.../CFS/lib/python3.9/site-packages/mne/channels/interpolation.py:108, in _make_interpolation_matrix(pos_from, pos_to, alpha)
    104     G_from.flat[::len(G_from) + 1] += alpha
    106 C = np.vstack([np.hstack([G_from, np.ones((n_from, 1))]),
    107                np.hstack([np.ones((1, n_from)), [[0]]])])
--> 108 C_inv = linalg.pinv(C)
    110 interpolation = np.hstack([G_to_from, np.ones((n_to, 1))]) @ C_inv[:, :-1]
    111 assert interpolation.shape == (n_to, n_from)

File ~/miniconda3/.../CFS/lib/python3.9/site-packages/scipy/linalg/basic.py:1315, in pinv(a, atol, rtol, return_rank, check_finite, cond, rcond)
   1241 def pinv(a, atol=None, rtol=None, return_rank=False, check_finite=True,
   1242          cond=None, rcond=None):
   1243     """
   1244     Compute the (Moore-Penrose) pseudo-inverse of a matrix.
   1245 
   (...)
   1313 
   1314     """
-> 1315     a = _asarray_validated(a, check_finite=check_finite)
   1316     u, s, vh = decomp_svd.svd(a, full_matrices=False, check_finite=False)
   1317     t = u.dtype.char.lower()

File ~/miniconda3/.../CFS/lib/python3.9/site-packages/scipy/_lib/_util.py:293, in _asarray_validated(a, check_finite, sparse_ok, objects_ok, mask_ok, as_inexact)
    291         raise ValueError('masked arrays are not supported')
    292 toarray = np.asarray_chkfinite if check_finite else np.asarray
--> 293 a = toarray(a)
    294 if not objects_ok:
    295     if a.dtype is np.dtype('O'):

File ~/miniconda3/envs/.../lib/python3.9/site-packages/numpy/lib/function_base.py:603, in asarray_chkfinite(a, dtype, order)
    601 a = asarray(a, dtype=dtype, order=order)
    602 if a.dtype.char in typecodes['AllFloat'] and not np.isfinite(a).all():
...
--> 603     raise ValueError(
    604         "array must not contain infs or NaNs")
    605 return a

ValueError: array must not contain infs or NaNs

sappelhoff commented 1 year ago

Mmmhh, seems like we have to dig into the code a bit more. Two requests:

  1. Can you pick some unrelated data, prune it down to 19 channels, and run it?
  2. Can you share your data (e.g., as .npy if you are using arrays) AND a minimal working example so I can replicate the error? I could then spend some time digging into this as well.

sunshineinsandiego commented 1 year ago

Sure, thanks for your help!

  1. Yes, I tried a few different datasets and got the same error.
  2. Example .npy data file attached (zipped), together with a minimal example:

import numpy as np
import mne
from find_noisy_channels import NoisyChannels

channels = ['F7', 'Fp1', 'Fp2', 'F8', 'F3', 'Fz', 'F4', 'C3', 'Cz', 'P8', 'P7', 'Pz', 'P4', 'T3', 'P3', 'O1', 'O2', 'C4', 'T4']
eeg_stream_srate = 500

data = np.load('data1.npy')

mne_info = mne.create_info(ch_names = channels, sfreq = eeg_stream_srate, ch_types = 'eeg')
mne_data = mne.io.RawArray(data, mne_info)

noisy_channels = NoisyChannels(mne_data, random_state = 1337)
noisy_channels.find_all_bads(ransac = True, channel_wise = True)

With the following output:

Setting up high-pass filter at 1 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal highpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
...
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Filter length: 1651 samples (3.302 sec)

exclude_from_ransac: ['F7', 'Fp1', 'Fp2', 'F8', 'Cz', 'T3', 'T4', 'Cz']

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)

File .../find_noisy_channels.py:222, in NoisyChannels.find_all_bads(self, ransac, channel_wise, max_chunk_size)
    220 self.find_bad_by_SNR()
    221 if ransac:
--> 222     self.find_bad_by_ransac(
    223         channel_wise=channel_wise, max_chunk_size=max_chunk_size
    224     )

File .../find_noisy_channels.py:578, in NoisyChannels.find_bad_by_ransac(self, n_samples, sample_prop, corr_thresh, frac_bad, corr_window_secs, channel_wise, max_chunk_size)
    576 print('exclude_from_ransac: {}'.format(exclude_from_ransac))
    577 rng = copy(self.random_state) if self.matlab_strict else self.random_state
--> 578 self.bad_by_ransac, ch_correlations_usable = find_bad_by_ransac(
    579     self.EEGFiltered,
    580     self.sample_rate,
    581     self.ch_names_new,
    582     self.raw_mne._get_channel_positions()[self.usable_idx, :],
    583     exclude_from_ransac,
    584     n_samples,
    585     sample_prop,
    586     corr_thresh,
    587     frac_bad,
    588     corr_window_secs,
    589     channel_wise,
    590     max_chunk_size,
    591     rng,
    592     self.matlab_strict,
    593 )
    595 # Reshape correlation matrix to match original channel count
    596 n_ransac_windows = ch_correlations_usable.shape[0]

File .../ransac.py:162, in find_bad_by_ransac(data, sample_rate, complete_chn_labs, chn_pos, exclude, n_samples, sample_prop, corr_thresh, frac_bad, corr_window_secs, channel_wise, max_chunk_size, random_state, matlab_strict)
    159     random_ch_picks.append(picks)
    161 # Generate interpolation matrix for each RANSAC sample
--> 162 interp_mats = _make_interpolation_matrices(random_ch_picks, chn_pos_good)
    164 # Calculate the size (in frames) and count of correlation windows
    165 correlation_frames = corr_window_secs * sample_rate

File..../ransac.py:284, in _make_interpolation_matrices(random_ch_picks, chn_pos_good)
    282     mat = np.zeros((n_chans_good, n_chans_good))
    283     subset_pos = chn_pos_good[sample, :]
--> 284     mat[:, sample] = _make_interpolation_matrix(subset_pos, chn_pos_good)
    285     interpolation_mats.append(mat)
    287 return interpolation_mats

File ~/miniconda3/envs/.../lib/python3.9/site-packages/mne/channels/interpolation.py:108, in _make_interpolation_matrix(pos_from, pos_to, alpha)
    104     G_from.flat[::len(G_from) + 1] += alpha
    106 C = np.vstack([np.hstack([G_from, np.ones((n_from, 1))]),
    107                np.hstack([np.ones((1, n_from)), [[0]]])])
--> 108 C_inv = linalg.pinv(C)
    110 interpolation = np.hstack([G_to_from, np.ones((n_to, 1))]) @ C_inv[:, :-1]
    111 assert interpolation.shape == (n_to, n_from)

File ~/miniconda3/envs/.../lib/python3.9/site-packages/scipy/linalg/basic.py:1315, in pinv(a, atol, rtol, return_rank, check_finite, cond, rcond)
   1241 def pinv(a, atol=None, rtol=None, return_rank=False, check_finite=True,
   1242          cond=None, rcond=None):
   1243     """
   1244     Compute the (Moore-Penrose) pseudo-inverse of a matrix.
   1245 
   (...)
   1313 
   1314     """
-> 1315     a = _asarray_validated(a, check_finite=check_finite)
   1316     u, s, vh = decomp_svd.svd(a, full_matrices=False, check_finite=False)
   1317     t = u.dtype.char.lower()

File ~/miniconda3/envs/.../lib/python3.9/site-packages/scipy/_lib/_util.py:293, in _asarray_validated(a, check_finite, sparse_ok, objects_ok, mask_ok, as_inexact)
    291         raise ValueError('masked arrays are not supported')
    292 toarray = np.asarray_chkfinite if check_finite else np.asarray
--> 293 a = toarray(a)
    294 if not objects_ok:
    295     if a.dtype is np.dtype('O'):

File ~/miniconda3/envs/.../lib/python3.9/site-packages/numpy/lib/function_base.py:603, in asarray_chkfinite(a, dtype, order)
    601 a = asarray(a, dtype=dtype, order=order)
    602 if a.dtype.char in typecodes['AllFloat'] and not np.isfinite(a).all():
...
--> 603     raise ValueError(
    604         "array must not contain infs or NaNs")
    605 return a

ValueError: array must not contain infs or NaNs

data1.zip

sappelhoff commented 1 year ago

Thanks, I can reproduce your error. Are you sure the data is correctly scaled? When trying to plot it, I get this:

[Image: plot of the loaded data]

Note the scaling:

[Image: amplitude scaling of the plot]

@sunshineinsandiego the data that you pass to mne_data = mne.io.RawArray(data, mne_info) MUST be in volts. See also the documentation: https://mne.tools/stable/generated/mne.io.RawArray.html
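If the stream delivers microvolts, a single multiplication before constructing the RawArray brings the data into volts. A minimal sketch, assuming the amplifier output is in µV (check the device documentation; the constant and the to_volts helper are illustrative, not part of MNE or PyPREP):

```python
import numpy as np

# Assumption: samples arrive in microvolts (µV); MNE expects volts (V).
MICROVOLT_TO_VOLT = 1e-6


def to_volts(data_uv):
    """Convert an (n_channels, n_times) array from µV to V."""
    return np.asarray(data_uv, dtype=float) * MICROVOLT_TO_VOLT


# Hypothetical 2-channel, 3-sample recording in µV
data_uv = np.array([[10.0, -25.0, 50.0],
                    [0.0, 100.0, -100.0]])
data_v = to_volts(data_uv)
```

The converted array would then be passed as mne.io.RawArray(to_volts(data), mne_info). As the thread goes on to show, correct scaling is good practice but is not what triggers the NaN error here.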

Some questions:

sunshineinsandiego commented 1 year ago

Thanks:

I don't do any preprocessing on the data before passing it into the PREP pipeline. Thanks for the note on volts, but I'm not sure scaling is the issue? F7 and the other outlier channels are already in exclude_from_ransac, so they are excluded before the RANSAC function runs.

sappelhoff commented 1 year ago

What is the original output format of the Cognionics acquisition software? I am wondering why you have to read the data as an npy array and create an MNE data structure from scratch.

> but I'm not sure scaling is the issue

I am also not so sure, but debugging is easier when we tackle the obvious and easy problems first and then progress. Don't you agree that the scaling of your data seems off after loading it into an MNE data structure?

sunshineinsandiego commented 1 year ago

The original output is an LSL stream, which is parsed into a raw array.

Agreed, the scaling is off, but rescaling all channels to volts does not fix the error.

sappelhoff commented 1 year ago

Okay, the solution was quite simple: your raw object has no channel positions. Try:

montage = mne.channels.make_standard_montage("standard_1020")
mne_data.set_montage(montage)
noisy_channels = NoisyChannels(mne_data, random_state=1337)
noisy_channels.find_all_bads(ransac=True, channel_wise=True)

We definitely need a better error message for this.
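Such an error message could come from a pre-flight check that runs before RANSAC. A minimal sketch, assuming positions are taken from the first three entries of each channel's loc array in raw.info["chs"] (to our understanding, channels created with mne.create_info and no montage carry NaN there). The function names channels_without_positions and require_positions are hypothetical, not part of PyPREP or MNE:

```python
import numpy as np


def channels_without_positions(ch_names, positions):
    """Return names of channels whose 3D position is missing (non-finite).

    `positions` is an (n_channels, 3) array of x, y, z coordinates.
    """
    positions = np.asarray(positions, dtype=float)
    if positions.shape != (len(ch_names), 3):
        raise ValueError("positions must have shape (n_channels, 3)")
    # A row is unusable if any coordinate is NaN or inf
    missing = ~np.isfinite(positions).all(axis=1)
    return [name for name, bad in zip(ch_names, missing) if bad]


def require_positions(ch_names, positions):
    """Fail early with a readable message instead of a NaN error from pinv."""
    missing = channels_without_positions(ch_names, positions)
    if missing:
        raise ValueError(
            f"{len(missing)} channel(s) have no position: {missing}. "
            "Set a montage (e.g. raw.set_montage(...)) before running RANSAC."
        )


# Hypothetical example: Pz was never given a position
names = ["Fz", "Cz", "Pz"]
pos = np.array([[0.0, 0.06, 0.08],
                [0.0, 0.0, 0.10],
                [np.nan, np.nan, np.nan]])
print(channels_without_positions(names, pos))  # ['Pz']
```

With an MNE Raw object, the positions could be gathered as np.array([ch["loc"][:3] for ch in mne_data.info["chs"]]) and checked together with mne_data.ch_names before constructing NoisyChannels.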

sunshineinsandiego commented 1 year ago

Amazing, thanks for figuring that out!

sappelhoff commented 1 year ago

Thanks for your report 👍