mne-tools / mne-connectivity

Connectivity algorithms that leverage the MNE-Python API.
https://mne.tools/mne-connectivity/dev/index.html
BSD 3-Clause "New" or "Revised" License

error running spectral_connectivity_time on two selected channels #153

Closed: rythorpe closed this issue 8 months ago

rythorpe commented 8 months ago

I've been trying to use mne_connectivity.spectral_connectivity_time to calculate the Granger causality between two selected channels of data, but got the following traceback (snippet). I suspect something is wrong with how the indices are used to select data channels, but it's also possible that I'm misunderstanding what this variable does. (My intention is to calculate GC from the 1st to the 3rd channel.)

Estimated data ranks:
    connection 1 - seeds (1); targets (1)
Connectivity computation...
   Processing epoch 1 / 30 ...
Computing GC for connection 1 of 1
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
File ~/Dropbox (Brown)/wm_spec_events/err_demo.py:21
     17 freqs = np.arange(fmin, fmax + 1, 1.0)
     18 indices = (np.array([0]),
     19            np.array([2]))
---> 21 con = spectral_connectivity_time(epochs, freqs=freqs, method="gc",
     22                                  indices=indices, mode="cwt_morlet",
     23                                  fmin=fmin, fmax=fmax, faverage=True,
     24                                  gc_n_lags=25)

File <decorator-gen-620>:12, in spectral_connectivity_time(data, freqs, method, average, indices, sfreq, fmin, fmax, fskip, faverage, sm_times, sm_freqs, sm_kernel, padding, mode, mt_bandwidth, n_cycles, gc_n_lags, rank, decim, n_jobs, verbose)

File ~/.local/lib/python3.9/site-packages/mne_connectivity/spectral/time.py:498, in spectral_connectivity_time(data, freqs, method, average, indices, sfreq, fmin, fmax, fskip, faverage, sm_times, sm_freqs, sm_kernel, padding, mode, mt_bandwidth, n_cycles, gc_n_lags, rank, decim, n_jobs, verbose)
    496 for epoch_idx in np.arange(n_epochs):
    497     logger.info(f'   Processing epoch {epoch_idx+1} / {n_epochs} ...')
--> 498     scores, patterns = _spectral_connectivity(data[epoch_idx],
    499                                               **call_params)
    500     for m in method:
    501         conn[m][epoch_idx] = np.stack(scores[m], axis=0)

File ~/.local/lib/python3.9/site-packages/mne_connectivity/spectral/time.py:667, in _spectral_connectivity(data, method, kernel, foi_idx, source_idx, target_idx, signals_use, mode, sfreq, freqs, faverage, n_cycles, mt_bandwidth, gc_n_lags, rank, decim, padding, kw_cwt, kw_mt, n_jobs, verbose, multivariate_con)
    665 scores = {}
    666 patterns = {}
--> 667 conn = _parallel_con(out, method, kernel, foi_idx, source_idx, target_idx,
    668                      signals_use, gc_n_lags, rank, n_jobs, verbose,
    669                      n_pairs, faverage, weights, multivariate_con)
    670 for i, m in enumerate(method):
    671     if multivariate_con:

File ~/.local/lib/python3.9/site-packages/mne_connectivity/spectral/time.py:762, in _parallel_con(w, method, kernel, foi_idx, source_idx, target_idx, signals_use, gc_n_lags, rank, n_jobs, verbose, total, faverage, weights, multivariate_con)
    755     parallel, my_pairwise_con, n_jobs = parallel_func(
    756         _pairwise_con, n_jobs=n_jobs, verbose=verbose, total=total)
    758     return tuple(parallel(
    759         my_pairwise_con(w, psd, s, t, method, kernel, foi_idx, faverage,
    760                         weights) for s, t in zip(source_idx, target_idx)))
--> 762 return _multivariate_con(w, source_idx, target_idx, signals_use, method,
    763                          kernel, foi_idx, faverage, weights, gc_n_lags,
    764                          rank, n_jobs)

File ~/.local/lib/python3.9/site-packages/mne_connectivity/spectral/time.py:898, in _multivariate_con(w, source_idx, target_idx, signals_use, method, kernel, foi_idx, faverage, weights, gc_n_lags, rank, n_jobs)
    896 patterns = []
    897 for con_est in conn:
--> 898     con_est.compute_con(np.array([source_idx, target_idx]), rank)
    899     scores.append(con_est.con_scores[..., np.newaxis])
    900     patterns.append(con_est.patterns)

File ~/.local/lib/python3.9/site-packages/mne_connectivity/spectral/epochs.py:896, in _GCEstBase.compute_con(self, indices, ranks, n_epochs)
    893 self._log_connection_number(con_i)
    895 con_idcs = [*seed_idcs, *target_idcs]
--> 896 C = csd[np.ix_(times, freqs, con_idcs, con_idcs)]
    898 con_seeds = np.arange(len(seed_idcs))
    899 con_targets = np.arange(len(target_idcs)) + len(seed_idcs)

IndexError: index 2 is out of bounds for axis 2 with size 2
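One plausible reading of the traceback, sketched with plain NumPy: it assumes the cross-spectral density (CSD) reaching `compute_con` holds only the two channels actually used (so its channel axes have size 2), while `con_idcs` still contains the original channel numbers 0 and 2. The array shapes here are hypothetical, and remapping with `np.searchsorted` is just one way to turn channel numbers into positions; it is not necessarily how #142 fixes the bug.

```python
import numpy as np

# Hypothetical shapes: 1 time point, 3 frequencies, and a CSD computed
# only for the 2 channels actually used (channels 0 and 2).
n_times, n_freqs = 1, 3
signals_use = np.array([0, 2])  # positions in the full channel list (sorted)
csd = np.zeros((n_times, n_freqs, signals_use.size, signals_use.size))

times = np.arange(n_times)
freqs = np.arange(n_freqs)
seed_idcs, target_idcs = [0], [2]  # GC from the 1st to the 3rd channel

# Indexing with the original channel numbers fails, as in the traceback.
con_idcs = [*seed_idcs, *target_idcs]
msg = None
try:
    csd[np.ix_(times, freqs, con_idcs, con_idcs)]
except IndexError as err:
    msg = str(err)
print(msg)  # index 2 is out of bounds for axis 2 with size 2

# Mapping each channel number to its position within signals_use works
# (searchsorted requires signals_use to be sorted).
remapped = np.searchsorted(signals_use, con_idcs)
C = csd[np.ix_(times, freqs, remapped, remapped)]
print(remapped.tolist(), C.shape)  # [0, 1] (1, 3, 2, 2)
```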

Here's an MWE using the same data referenced in your Granger causality example.

import numpy as np

import mne
from mne.datasets.fieldtrip_cmc import data_path
from mne_connectivity import spectral_connectivity_time

raw = mne.io.read_raw_ctf(data_path() / 'SubjectCMC.ds')
raw.pick('mag')
raw.crop(50., 110.).load_data()
raw.notch_filter(50)
raw.resample(100)

epochs = mne.make_fixed_length_epochs(raw, duration=2.0).load_data()

fmin, fmax = 15.0, 29.0
freqs = np.arange(fmin, fmax + 1, 1.0)
indices = (np.array([0]),
           np.array([2]))

con = spectral_connectivity_time(epochs, freqs=freqs, method="gc",
                                 indices=indices, mode="cwt_morlet",
                                 fmin=fmin, fmax=fmax, faverage=True,
                                 gc_n_lags=25)

Here's info about my system and MNE installation:

Platform             Linux-5.15.0-87-generic-x86_64-with-glibc2.35
Python               3.9.18 (main, Sep 11 2023, 13:41:44) [GCC 11.2.0]
Executable           /home/ryan/anaconda3/envs/mne/bin/python
CPU                  x86_64 (12 cores)
Memory               15.5 GB

Core
├☑ mne 1.5.1
├☑ numpy 1.26.0 (MKL 2023.1-Product with 6 threads)
├☑ scipy 1.11.3
├☑ matplotlib 3.7.2 (backend=QtAgg)
├☑ pooch 1.7.0
└☑ jinja2 3.1.2

Numerical (optional)
├☑ sklearn 1.3.0
├☑ pandas 2.1.1
└☐ unavailable numba, nibabel, nilearn, dipy, openmeeg, cupy

Visualization (optional)
├☑ qtpy 2.2.0 (PyQt5=5.15.2)
├☑ ipympl 0.8.7
├☑ pyqtgraph 0.13.3
├☑ mne-qt-browser 0.6.0
├☑ ipywidgets 7.6.5
└☐ unavailable pyvista, pyvistaqt, vtk, trame_client, trame_server, trame_vtk, trame_vuetify

Ecosystem (optional)
├☑ mne-connectivity 0.6.0dev0
└☐ unavailable mne-bids, mne-nirs, mne-features, mne-icalabel, mne-bids-pipeline

adam2392 commented 8 months ago

@tsbinns any idea on this?

tsbinns commented 8 months ago

@rythorpe Thanks for bringing this to my attention. You are not doing anything wrong; this was a mistake on my part.

@adam2392 This is something I fixed in #142. If you like, I can submit a fresh PR with a hotfix for this specific problem.

adam2392 commented 8 months ago

Ah, I see. How close are you to something reviewable, given the discussion we had in #142?

@rythorpe, if you're comfortable, perhaps you can check out the change in #142? Support for multivariate connectivity was recently added by @tsbinns, and we're doing a refactor before we finalize this feature in v0.6.

tsbinns commented 8 months ago

@adam2392 I think one matter requires further discussion, so I'm not sure how quickly the fix there would be merged. I will post a more detailed update there now.

rythorpe commented 8 months ago

Thanks for the reply! I can just fetch the branch on #142 for now.

tsbinns commented 8 months ago

@rythorpe No problem! Feel free to write me again if anything is unclear.

rythorpe commented 8 months ago

Are there any major API changes I should know about, @tsbinns? It looks like your updated docstring says something about providing nested arrays for the seeds and targets in indices for spectral_connectivity_time. I'm guessing that means I should now use e.g.

indices = (np.array([[0]]),
           np.array([[2]]))

instead of

indices = (np.array([0]),
           np.array([2]))

to specify that I want to calculate GC from the 1st to 3rd channel?

tsbinns commented 8 months ago

@rythorpe Exactly. The goal of #142 is to add support for computing connectivity on multiple multivariate connections in the same function call. This requires indexing the seeds and targets of a given connection as a nested array. As you say, this is correct:

indices = (np.array([[0]]),
           np.array([[2]]))
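A minimal NumPy sketch of the flat and nested formats above, assuming only what is stated here: in the nested form, row i of the seed (or target) array lists the seed (or target) channels of connection i. `to_nested` is a hypothetical helper for illustration, not part of mne-connectivity, and the sketch does not call the library.

```python
import numpy as np

# Flat (bivariate-style) indices from the original report:
# connection i goes from seeds[i] to targets[i].
flat_indices = (np.array([0]), np.array([2]))

# Nested indices as described for #142: row i of each array lists all
# seed (or target) channels of connection i.
nested_indices = (np.array([[0]]), np.array([[2]]))


def to_nested(indices):
    """Hypothetical helper: put each flat seed/target into its own row."""
    seeds, targets = indices
    return (np.asarray(seeds)[:, np.newaxis],
            np.asarray(targets)[:, np.newaxis])


converted = to_nested(flat_indices)
same = all(np.array_equal(a, b) for a, b in zip(converted, nested_indices))
print(same)  # True

# Each (seed row, target row) pair describes one connection.
for con_i, (seeds, targets) in enumerate(zip(*nested_indices)):
    print(f"connection {con_i}: seeds {seeds.tolist()} -> "
          f"targets {targets.tolist()}")
```

With a single seed and a single target, the nested form only adds one level of brackets, which is why `[[0]]` and `[[2]]` express the same 1st-to-3rd-channel connection as the original flat indices.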

My latest refactoring attempt has not yet touched spectral_connectivity_time, so you can still use that function to call the multivariate methods.

It is also now possible to compute connectivity where the number of seed channels differs from the number of target channels, which is explained further in this example.
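A sketch of what unequal seed and target counts look like in the nested format, assuming one connection with two seeds (1st and 2nd channels) and one target (3rd channel). The random `data` array is a hypothetical stand-in for one epoch of 3-channel data; the slicing only shows which signals each row selects, not how mne-connectivity processes them internally.

```python
import numpy as np

# One multivariate connection: seeds are the 1st and 2nd channels,
# the target is the 3rd channel (seed and target rows differ in length).
indices = (np.array([[0, 1]]), np.array([[2]]))
seeds, targets = indices

# Hypothetical stand-in for one epoch: 3 channels x 200 samples.
rng = np.random.default_rng(0)
data = rng.standard_normal((3, 200))

for con_i, (seed_chs, target_chs) in enumerate(zip(seeds, targets)):
    seed_data = data[seed_chs]      # shape (2, 200)
    target_data = data[target_chs]  # shape (1, 200)
    print(f"connection {con_i}: seed block {seed_data.shape}, "
          f"target block {target_data.shape}")
```

Because seeds and targets are stored in separate arrays, each can have its own width without any padding when there is only one connection.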

rythorpe commented 8 months ago

Got it, thanks @tsbinns!

rythorpe commented 8 months ago

Fixed in #142. Thanks @tsbinns!