mne-tools / mne-connectivity

Connectivity algorithms that leverage the MNE-Python API.
https://mne.tools/mne-connectivity/dev/index.html
BSD 3-Clause "New" or "Revised" License

Within-epoch surrogate generation #251

Open seqasim opened 1 day ago

seqasim commented 1 day ago

Describe the problem

It is difficult to interpret many connectivity estimates without comparison to "null" surrogate data. As stated in Aru et al., 2015: "A suitable surrogate construction should only destroy the specific cyclo-stationarities related to the hypothesized... effect, while keeping all the unspecific non-stationarities and non-linearities of the original data."

Describe your solution

It would make sense to perform trial/epoch level scrambling so that non-stationarities and phase distributions within a trial/epoch are conserved. Tensorpac achieves this with a simple function, which I have edited slightly:

import numpy as np


def swap_time_blocks(data, random_state=None):
    """Compute surrogates by swapping time blocks.

    This function cuts the timeseries at a random time point. Then, both time
    blocks are swapped.

    Parameters
    ----------
    data : array_like
        Array of shape (n_chan, ..., n_times).
    random_state : int | None
        Seed for the random number generator, for reproducible results.

    Returns
    -------
    surr : array_like
        Swapped timeseries to use to compute the distribution of
        permutations.

    References
    ----------
    Source: https://www.sciencedirect.com/science/article/pii/S0959438814001640
    """
    if random_state is None:
        # draw a scalar seed (converting a size-1 array to int is deprecated)
        random_state = int(np.random.randint(0, 10000))
    rnd = np.random.RandomState(random_state)

    # random cutting point along the time axis, leaving both blocks non-empty
    min_shift, max_shift = 1, data.shape[-1]
    cut_at = rnd.randint(min_shift, max_shift)
    # split the data along the time axis into two blocks
    surr = np.array_split(data, [cut_at], axis=-1)
    # reverse the order of the two blocks
    surr.reverse()

    return np.concatenate(surr, axis=-1)
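
Swapping the two blocks around a single cut is the same operation as a circular shift of the time axis. The sketch below illustrates that equivalence with `np.roll`; the helper name `swap_time_blocks_roll` is hypothetical and this formulation is not taken from Tensorpac.

```python
import numpy as np


def swap_time_blocks_roll(data, cut_at):
    """Swap the blocks before/after `cut_at` via a circular shift (sketch)."""
    # shifting left by `cut_at` moves data[..., cut_at:] to the front
    return np.roll(data, -cut_at, axis=-1)


# compare against the split / reverse / concatenate formulation
data = np.arange(12).reshape(2, 6)
cut_at = 2
blocks = np.array_split(data, [cut_at], axis=-1)
blocks.reverse()
swapped = np.concatenate(blocks, axis=-1)
assert np.array_equal(swapped, swap_time_blocks_roll(data, cut_at))
```

Seen this way, the only random quantity per surrogate is the shift, which makes it easier to reason about what the procedure preserves (the waveform up to one discontinuity) and what it disrupts (alignment to the epoch onset).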

Example usage with mne data:

    # assumes `import mne`, `import numpy as np`, and that mne_data is an Epochs object
    data = np.swapaxes(mne_data.get_data(copy=False), 0, 1)  # now (chan, events, times)
    surr_dat = np.zeros_like(data)  # allocate space for the surrogate channels
    for ix, ch_dat in enumerate(data):
        # same cut for every event within a channel; cuts are drawn separately per channel
        surr_dat[ix] = swap_time_blocks(ch_dat, random_state=None)
    surr_dat = np.swapaxes(surr_dat, 0, 1)  # swap back to (events, chan, times)
    # make a new EpochsArray from it
    surr_mne = mne.EpochsArray(
        surr_dat,
        mne_data.info,
        tmin=mne_data.tmin,
        events=mne_data.events,
        event_id=mne_data.event_id,
    )

tsbinns commented 1 day ago

Hi @seqasim,

I wonder how much this need is addressed with this PR I wrote over the summer: https://github.com/mne-tools/mne-connectivity/pull/223

It has a similar purpose: generating surrogate estimates of null connectivity, based on the procedure (briefly) described in these publications: https://www.sciencedirect.com/science/article/pii/S1053811923003695; https://www.biorxiv.org/content/10.1101/2023.10.26.564193v1.abstract

This example demonstrates the approach: https://output.circle-artifacts.com/output/job/bafe351e-3009-46ea-92ce-863de6768196/artifacts/0/dev/auto_examples/surrogate_connectivity.html#sphx-glr-auto-examples-surrogate-connectivity-py

Cheers!

seqasim commented 1 day ago

Yes, I saw this nice addition! But I think the example you posted highlights the need for intra-trial shuffling:

"Critically, evoked data contains a temporal structure that is consistent across epochs, and thus shuffling epochs across channels will fail to adequately disrupt the covariance structure."

I believe that cutting the data within the trial, as I propose, would deal with this issue?

tsbinns commented 1 day ago

Cutting just once within the trial is not something we had looked into before.

I adapted the code you wrote into a function that performs this shuffling multiple times and incorporated it into the example I linked to:

[image]

Orange is the connectivity for the evoked response, red is the connectivity for the surrogate with the proposed within-epoch shuffling, and blue is the connectivity from the pre-stimulus surrogate with the existing procedure in that PR.

It seems there is still a lot of residual connectivity information preserved using the within-epoch shuffling. Perhaps also cutting the data independently for each epoch would reduce this. Making additional cuts within each epoch could be tricky, as then you move towards creating white noise.

I don't think it's due to an error when I adapted the code, but here it is in case I missed something:

import numpy as np


def make_surrogate_data_alt(
    data, n_shuffles=1000, rng_seed=None, return_generator=True
):
    """Create surrogate data for a null hypothesis of connectivity."""
    surrogate = _shuffle_within_epochs(data, n_shuffles, rng_seed)
    if not return_generator:
        # exhaust the generator so that all shuffles are held in memory
        surrogate = list(surrogate)

    return surrogate


def _shuffle_within_epochs(data, n_shuffles, rng_seed):
    """Shuffle within epochs in data."""
    from mne import EpochsArray

    data_arr = data.get_data(copy=True)  # (n_epochs, n_channels, n_times)
    rng = np.random.default_rng(rng_seed)
    for _ in range(n_shuffles):
        surr_arr = np.zeros_like(data_arr)
        # one cut per channel, shared by all epochs of that channel
        for ch_idx in range(data_arr.shape[1]):
            surr_arr[:, ch_idx] = _swap_time_blocks(data_arr[:, ch_idx], rng)
        yield EpochsArray(surr_arr, info=data.info)


def _swap_time_blocks(data, rng):
    """Swap time blocks in data."""
    min_shift = 1
    max_shift = data.shape[-1]
    # random cut point, leaving both blocks non-empty
    cut_at = rng.integers(min_shift, max_shift, 1)
    surr = np.array_split(data, cut_at, axis=-1)
    surr.reverse()
    return np.concatenate(surr, axis=-1)

seqasim commented 1 day ago

Wow, thanks for adapting this so quickly! Lovely to see a head-to-head comparison.

Based on this real-data example, though, it's not clear to me whether the pre-stimulus surrogate (blue) will give false-positive assessments or my post-stimulus surrogate (red) will give false-negative assessments. Maybe the true task-evoked effect in this MEG data really is in the 6-12 Hz range?

Would it perhaps be better to simulate data with varying levels of frequency-specific power and true connectivity and test which surrogate method recovers ground truth most effectively?
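
As a toy version of that idea, the sketch below simulates two channels with a known 10 Hz coupling and compares coherence before and after per-epoch, per-channel cut-and-swap. Everything here is an assumption for illustration: the signal model (a shared sinusoid with random phase per epoch, closer to resting-state than evoked data), the parameter values, and the helper names `coherence` and `cut_and_swap`. It uses plain NumPy rather than any mne-connectivity function.

```python
import numpy as np

rng = np.random.default_rng(0)
n_epochs, n_times, sfreq, freq = 100, 500, 250.0, 10.0
t = np.arange(n_times) / sfreq

# Two channels share a 10 Hz rhythm with a fixed pi/4 lag. The rhythm's phase
# is random in each epoch, and each channel has independent white noise.
phase = rng.uniform(0, 2 * np.pi, size=(n_epochs, 1))
x = np.sin(2 * np.pi * freq * t + phase)
y = np.sin(2 * np.pi * freq * t + phase + np.pi / 4)
data = np.stack([x, y], axis=1)  # (n_epochs, n_channels, n_times)
data = data + rng.standard_normal(data.shape)


def coherence(data):
    """Coherence between channels 0 and 1, averaged over epochs."""
    spec = np.fft.rfft(data, axis=-1)
    csd = np.mean(spec[:, 0] * np.conj(spec[:, 1]), axis=0)
    psd = np.mean(np.abs(spec) ** 2, axis=0)
    return np.abs(csd) / np.sqrt(psd[0] * psd[1])


def cut_and_swap(data, rng):
    """Independent cut-and-swap for every epoch and channel."""
    surr = np.empty_like(data)
    cuts = rng.integers(1, data.shape[-1], size=data.shape[:2])
    for ep in range(data.shape[0]):
        for ch in range(data.shape[1]):
            surr[ep, ch] = np.roll(data[ep, ch], -cuts[ep, ch])
    return surr


freq_bin = int(round(freq * n_times / sfreq))  # FFT bin of the 10 Hz rhythm
coh_true = coherence(data)[freq_bin]
coh_null = coherence(cut_and_swap(data, rng))[freq_bin]
print(f"coherence at {freq} Hz: data={coh_true:.2f}, surrogate={coh_null:.2f}")
```

A fuller comparison would vary the coupling strength and signal-to-noise ratio, add an evoked component that is phase-locked to the epoch onset, and repeat the surrogate many times to build a null distribution.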

Or at least it might make sense to provide users with all three options: pre-stimulus epoch shuffle (most permissive), post-stimulus epoch shuffle (medium permissive), and post-stimulus cut and swap (least permissive)?

tsbinns commented 1 day ago

"Would it perhaps be better to simulate data with varying levels of frequency-specific power and true connectivity and test which surrogate method recovers ground truth most effectively?"

Definitely. For the between-channel epoch shuffling, this was tested using data from the make_surrogate_connectivity() function, which uses a pretty standard approach of generating spectral connectivity data. However, the data from this should be considered more as resting-state data. This approach would need to be adapted to generate connectivity for evoked data.

Still, even for resting-state data, the within-epoch shuffling would need to be able to identify this baseline.

seqasim commented 21 hours ago

Ah I see. I think you're right about cutting the data independently for every epoch for evoked data. I modified your code to do so. Could you possibly test this with the example you provided above?

import numpy as np
from mne import EpochsArray

def make_surrogate_data_alt(data, n_shuffles=1000, rng_seed=None, return_generator=True):
    """Create surrogate data for a null hypothesis of connectivity."""
    surrogate = _shuffle_within_epochs(data, n_shuffles, rng_seed)
    if not return_generator:
        surrogate = [shuffle for shuffle in surrogate]
    return surrogate

def _shuffle_within_epochs(data, n_shuffles, rng_seed):
    """Shuffle within epochs in data."""
    data_arr = data.get_data(copy=True)
    rng = np.random.default_rng(rng_seed)
    for _ in range(n_shuffles):
        surr_arr = np.zeros_like(data_arr)
        cutpoints = rng.integers(1, data_arr.shape[-1], (data_arr.shape[0], data_arr.shape[1]))
        for ev_idx in range(data_arr.shape[0]):
            for ch_idx in range(data_arr.shape[1]):
                surr_arr[ev_idx, ch_idx] = _swap_time_blocks(data_arr[ev_idx, ch_idx], cutpoints[ev_idx, ch_idx])
        yield EpochsArray(surr_arr, info=data.info)

def _swap_time_blocks(data, cut_at):
    """Swap time blocks in data at a given cutpoint."""
    surr = np.array_split(data, [cut_at], axis=-1)
    surr.reverse()
    return np.concatenate(surr, axis=-1)

Note that cutpoints should now contain an independently drawn cut for each event and each channel. This should preserve within-trial dynamics, but destroy across-trial residual connectivity. Thoughts?
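
If the nested loop over events and channels becomes a bottleneck, the same per-event, per-channel cut can be applied in one indexing step. This is a minimal sketch under the assumption that a single cut is equivalent to a circular shift; the function name `swap_time_blocks_vectorised` is hypothetical and not part of the code above or of mne-connectivity.

```python
import numpy as np


def swap_time_blocks_vectorised(data_arr, rng):
    """Cut-and-swap every (event, channel) timeseries at its own cutpoint.

    data_arr has shape (n_events, n_channels, n_times). Returns the
    surrogate array and the cutpoints that were used.
    """
    n_times = data_arr.shape[-1]
    # one cutpoint per event and channel, leaving both blocks non-empty
    cuts = rng.integers(1, n_times, size=data_arr.shape[:-1])
    # output sample i is read from input sample (i + cut) modulo n_times
    idx = (np.arange(n_times) + cuts[..., np.newaxis]) % n_times
    return np.take_along_axis(data_arr, idx, axis=-1), cuts


rng = np.random.default_rng(42)
example = rng.standard_normal((4, 3, 20))
surr, cuts = swap_time_blocks_vectorised(example, rng)
```

The index array has the same shape as the data, so this trades the Python loop for extra memory.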

tsbinns commented 18 hours ago

This is what I get with the code you modified (again, blue is the existing method; red is the suggested method):

[image]

The additional independent shuffling for each epoch definitely helps.

While this new approach is somewhat simpler for working with evoked data, it is more computationally demanding since you need to compute the Fourier coefficients for each shuffle of the timeseries, whereas the approach in the existing PR allows you to compute the Fourier coefficients once and then shuffle these.
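
One possible way to narrow that gap, offered as a sketch rather than something tested in this thread: for a plain, untapered DFT over the whole epoch, a circular shift in time is a linear phase ramp in frequency, so a cut-and-swap surrogate can be derived from coefficients that were computed once. The function name `shift_fourier_coefficients` is hypothetical, and the identity does not carry over directly to tapered, multitaper, or wavelet estimates.

```python
import numpy as np


def shift_fourier_coefficients(coeffs, cuts, n_times):
    """Apply a cut-and-swap at `cuts` to rfft coefficients (sketch).

    coeffs has shape (..., n_freqs) from np.fft.rfft over n_times samples;
    cuts has shape (...) and holds one integer cutpoint per timeseries.
    """
    k = np.arange(coeffs.shape[-1])  # frequency bin indices
    # DFT shift theorem: y[n] = x[(n + c) % N]  <=>  Y[k] = X[k] exp(2j pi k c / N)
    ramp = np.exp(2j * np.pi * k * cuts[..., np.newaxis] / n_times)
    return coeffs * ramp


rng = np.random.default_rng(0)
n_times = 64
data_arr = rng.standard_normal((5, 3, n_times))
cuts = rng.integers(1, n_times, size=data_arr.shape[:-1])

# reference: cut and swap in the time domain, then transform
idx = (np.arange(n_times) + cuts[..., np.newaxis]) % n_times
time_domain = np.fft.rfft(np.take_along_axis(data_arr, idx, axis=-1), axis=-1)

# alternative: transform once, then apply the phase ramp for each surrogate
freq_domain = shift_fourier_coefficients(np.fft.rfft(data_arr, axis=-1), cuts, n_times)
assert np.allclose(time_domain, freq_domain)
```

Whether this is usable in practice depends on the spectral estimator behind the connectivity method, which is not settled here.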

I would be interested to hear the thoughts of @larsoner, @wmvanvliet, and @drammock, and if there's a way both approaches could be made available.