mne-tools / mne-connectivity

Connectivity algorithms that leverage the MNE-Python API.
https://mne.tools/mne-connectivity/dev/index.html
BSD 3-Clause "New" or "Revised" License

[BUG] Enable averaging connectivity over tapers instead of averaging PSD over tapers #84

Closed: adam2392 closed this issue 1 year ago

adam2392 commented 2 years ago

As of MNE-Python v1.0, a bug fix means the TFR in multitaper analysis is no longer averaged over tapers unless the user asks for it. So to use that in mne-connectivity, we need to compute the power for each taper, then compute connectivity, and then average across tapers.

I see. Thanks for this comment. Perhaps that is why there are some issues in #73.

_Originally posted by @adam2392 in https://github.com/mne-tools/mne-connectivity/pull/82#discussion_r805027992_

adam2392 commented 2 years ago

@Div12345 and @avoide

https://github.com/mne-tools/mne-connectivity/blob/711ccce68545c14dd9aab40213efe3e1f06e18cf/mne_connectivity/spectral/time.py#L269-L282

These lines may be contributing to the discrepancy we see. Until now, we have been averaging the multitaper estimates across tapers before computing connectivity. We should instead compute connectivity for each taper and then average.

Div12345 commented 2 years ago

@adam2392 Is there a regression test that already covers this? I commented out lines 276-277 and uncommented lines 281-282, and all tests seem to pass. If not, I could try writing a regression test and checking it out.

adam2392 commented 2 years ago

No, unfortunately there is not. What made it not a "trivial fix" is that conn_func needs to be applied for each taper, but right now the API doesn't support doing something like that.

I would first identify which pieces of code need to change and propose some quick pseudocode here to see if it makes sense.
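As a rough illustration of the kind of pseudocode requested here, the sketch below applies a pairwise connectivity function to each taper separately and only then averages across tapers. The names `conn_per_taper` and `abs_coherence`, and the assumed `(n_samples, n_tapers, n_freqs)` layout of the complex coefficients, are hypothetical and not part of the mne-connectivity API.

```python
import numpy as np


def abs_coherence(x, y):
    """Hypothetical example conn_func: |coherence| over the sample axis (0)."""
    sxy = np.mean(x * np.conj(y), axis=0)
    sxx = np.mean(np.abs(x) ** 2, axis=0)
    syy = np.mean(np.abs(y) ** 2, axis=0)
    return np.abs(sxy) / np.sqrt(sxx * syy)


def conn_per_taper(conn_func, x, y, taper_axis=1, weights=None):
    """Apply conn_func to each taper separately, then average over tapers.

    x, y: complex coefficients for two channels, with tapers on taper_axis.
    weights: optional per-taper weights (uniform average if None).
    """
    # Move the taper axis to the front so we can iterate over tapers.
    x_t = np.moveaxis(x, taper_axis, 0)
    y_t = np.moveaxis(y, taper_axis, 0)
    per_taper = np.stack([conn_func(xk, yk) for xk, yk in zip(x_t, y_t)])
    return np.average(per_taper, axis=0, weights=weights)


rng = np.random.default_rng(0)
shape = (40, 3, 5)  # (n_samples, n_tapers, n_freqs), assumed layout
x = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
y = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
conn = conn_per_taper(abs_coherence, x, y)  # shape (n_freqs,)
```

Passing `weights` turns the uniform average over tapers into a weighted one without changing the per-taper step.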

ruuskas commented 1 year ago

I'm currently working on this issue and probably have already solved it, barring testing.

adam2392 commented 1 year ago

Sure, sounds good! @ruuskas, lmk if there's a PR you would like help iterating on. I can help with testing if needed.

ruuskas commented 1 year ago

Sure @adam2392, I would like some help with the upcoming PR. I'm in the process of making the API more similar to spectral_connectivity_epochs, and I have also fixed the issue with PLV computation mentioned in #90 and #73.

adam2392 commented 1 year ago

Awesome, excited to see the PR! Feel free to start it early as a "WIP" draft.

ruuskas commented 1 year ago

@adam2392 I have submitted the PR. Could you have a look when you have time?

ruuskas commented 1 year ago

Hi @adam2392, @larsoner. I came back to this after finding some quite peculiar behaviour with the connectivity function. Are you sure that this conclusion is correct?

 # TODO: This is wrong -- it averages in the complex domain (over tapers). 
 # What it *should* do is compute the conn for each taper, then average 
 # (see below). 

I did some testing with this, and the connectivity results do not appear smooth at all (see plots below). Most importantly, there's a dip in connectivity at 10 Hz, which is the only actual coupling frequency in the simulated data. Morlet wavelet convolution is shown for comparison.

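[Figure: multitaper_plv_10hz, PLV vs. frequency computed with multitaper]

[Figure: morlet_10Hz_plv, PLV vs. frequency computed with Morlet wavelets]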

These were computed with the implementation found in #104.

I'm just guessing, but perhaps the cross-spectral density should be averaged and connectivity computed from that instead of averaging connectivity matrices? I will go through the code tomorrow to check for other potential bugs that could cause this.

larsoner commented 1 year ago

> Are you sure that this conclusion is correct?

In principle, yes, it seems like a reasonable conclusion to me.

I agree that the first plot looks buggy, though; I wouldn't expect it to look like that. Have you tried using multitaper with exactly one taper? That should look more like the Morlet case. Then I'd try with one taper again, but this time using only the second taper instead of the first. (This will potentially be a pain to code, but it should be doable.) Both of these should be smooth, similar to the Morlet case. If they are, the aggregate should also be smooth. From there, you could use exactly two tapers (the first and second) and check that the result is smooth. If it isn't, then it seems like there is a bug in the code that combines these values. If it is smooth, I'd gradually increase the number of tapers until I started seeing problems, then try to understand why they arise.
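A minimal sketch of the single-taper check suggested above, assuming whole-window FFTs rather than the time-resolved transform used by `tfr_array_multitaper`. It uses SciPy's `scipy.signal.windows.dpss` as a stand-in for MNE-Python's own DPSS routine, and `single_taper_coefs` is a hypothetical helper; the sampling rate, duration, and time-bandwidth values are arbitrary assumptions. Each call keeps only the selected tapers, so taper 0, taper 1, or a chosen subset can be compared directly.

```python
import numpy as np
from scipy.signal.windows import dpss


def single_taper_coefs(data, taper_idx, nw=2.0, n_tapers=3):
    """Fourier coefficients of data (..., n_times) using only selected DPSS tapers.

    taper_idx: int or list of ints choosing which tapers to keep.
    Returns an array shaped (..., n_selected_tapers, n_freqs).
    """
    n_times = data.shape[-1]
    tapers = dpss(n_times, nw, Kmax=n_tapers)  # (n_tapers, n_times)
    selected = tapers[np.atleast_1d(taper_idx)]  # (n_selected, n_times)
    # Broadcast each selected taper against the data, then FFT along time.
    return np.fft.rfft(data[..., np.newaxis, :] * selected, axis=-1)


sfreq, n_times = 250.0, 500
t = np.arange(n_times) / sfreq
sig = 0.3 * np.sin(2 * np.pi * 10.0 * t)  # pure 10 Hz test signal
freqs = np.fft.rfftfreq(n_times, 1 / sfreq)
first_only = single_taper_coefs(sig, 0)      # shape (1, n_freqs)
second_only = single_taper_coefs(sig, 1)     # shape (1, n_freqs)
first_two = single_taper_coefs(sig, [0, 1])  # shape (2, n_freqs)
```

One property to keep in mind when reading such single-taper results: odd-order DPSS tapers (index 1, 3, ...) are antisymmetric and sum to approximately zero, so for a pure sinusoid the `second_only` coefficients are close to zero exactly at 10 Hz and peak on either side of it.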

ruuskas commented 1 year ago

Hi @larsoner. I tried doing what you suggested. Indeed, the second taper is not smooth, nor is the third.

[Figures: multitaper-2tapers-first-only, multitaper-2tapers-second-only, multitaper-3tapers-third-only]

I also plotted the (absolute value of the) CSD computed from the output of tfr_array_multitaper with output='complex'.
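For reference, a sketch of the per-taper |CSD| computation described above. It assumes the complex multitaper output is laid out as `(n_epochs, n_chans, n_tapers, n_freqs, n_times)`, which is our understanding of `tfr_array_multitaper(..., output='complex')` in MNE-Python 1.x and should be checked against the installed version. `abs_csd_per_taper` is a hypothetical helper, and random numbers stand in for real coefficients.

```python
import numpy as np


def abs_csd_per_taper(coefs, i, j):
    """|CSD| between channels i and j, kept separate for each taper.

    coefs: complex array shaped (n_epochs, n_chans, n_tapers, n_freqs, n_times).
    Returns an array shaped (n_tapers, n_freqs), averaged over epochs and time.
    """
    cross = coefs[:, i] * np.conj(coefs[:, j])  # (n_epochs, n_tapers, n_freqs, n_times)
    return np.abs(cross.mean(axis=(0, 3)))


rng = np.random.default_rng(0)
shape = (10, 2, 3, 4, 50)  # assumed (n_epochs, n_chans, n_tapers, n_freqs, n_times)
coefs = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
csd = abs_csd_per_taper(coefs, 0, 1)  # shape (3, 4)
```

With `i == j` the same function returns each taper's average power, which is a convenient sanity check.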

[Figure: csd_for_tapers, |CSD| for each taper]

The simulated data in this case is a 10 Hz sinusoid with amplitude 0.3 and a random but constant phase in each channel, plus uniformly distributed random additive noise.
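A sketch of how such test data could be generated. The 10 Hz frequency, the 0.3 amplitude, and the per-channel constant random phase come from the description above; the sampling rate, duration, number of epochs, and the noise range (uniform in [-0.5, 0.5)) are assumptions, since the thread does not state them. `simulate_phase_locked` is a hypothetical name.

```python
import numpy as np


def simulate_phase_locked(n_epochs=30, n_chans=2, sfreq=250.0, n_times=500,
                          freq=10.0, amp=0.3, noise=0.5, seed=0):
    """Sinusoid at `freq` with a fixed random phase per channel plus uniform noise.

    Returns an array shaped (n_epochs, n_chans, n_times).
    """
    rng = np.random.default_rng(seed)
    t = np.arange(n_times) / sfreq
    # One random phase per channel, held constant across epochs and time,
    # so the channels are phase-locked at `freq`.
    phases = rng.uniform(0, 2 * np.pi, size=n_chans)
    signal = amp * np.sin(2 * np.pi * freq * t + phases[:, np.newaxis])
    noise_arr = rng.uniform(-noise, noise, size=(n_epochs, n_chans, n_times))
    return signal[np.newaxis] + noise_arr


data = simulate_phase_locked()
```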

ruuskas commented 1 year ago

Hi @larsoner, @adam2392!

Here's a comparison of three different ways to average over tapers: averaging the connectivity matrices, averaging the complex signal, and averaging the CSD. I plotted five different connectivity metrics against frequency. The underlying signal in the simulated channels is a 10 Hz phase-locked sinusoid with additive noise.
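To make the three strategies concrete, here is a sketch that uses absolute coherence as the example metric. The `(n_samples, n_tapers, n_freqs)` layout and the function names are assumptions for illustration; this is not the code behind the plots in this thread, and the averaging over tapers here is uniform.

```python
import numpy as np


def _spectra(a, b, axis):
    """Cross-spectrum and both auto-spectra, averaged over `axis`."""
    return (np.mean(a * np.conj(b), axis=axis),
            np.mean(np.abs(a) ** 2, axis=axis),
            np.mean(np.abs(b) ** 2, axis=axis))


def _coh(sxy, sxx, syy):
    return np.abs(sxy) / np.sqrt(sxx * syy)


def coherence_three_ways(x, y):
    """x, y: complex coefficients shaped (n_samples, n_tapers, n_freqs)."""
    # 1) connectivity for each taper, then average the connectivity values
    avg_conn = _coh(*_spectra(x, y, axis=0)).mean(axis=0)
    # 2) average the complex coefficients over tapers, then connectivity
    avg_complex = _coh(*_spectra(x.mean(axis=1), y.mean(axis=1), axis=0))
    # 3) average cross-/auto-spectra over samples and tapers, then connectivity
    avg_csd = _coh(*_spectra(x, y, axis=(0, 1)))
    return avg_conn, avg_complex, avg_csd


rng = np.random.default_rng(0)
shape = (40, 3, 6)
x = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
y = 0.5 * x + rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
results = coherence_three_ways(x, y)  # three arrays of shape (n_freqs,)
```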

It looks like averaging the CSD would give the smoothest result. This conclusion appears to be supported by the formulation on Wikipedia. WDYT?

[Figure: multitaper_taper_averaging, five connectivity metrics vs. frequency for the three taper-averaging approaches]

larsoner commented 1 year ago

In the "average connectivity" case, is it a weighted average? I think it needs to be in order to be comparable. A naive/simple average would unfairly bias the results toward the tapers with less power, compared to the average-complex and average-CSD cases (which implicitly include the weighting, IIUC).

larsoner commented 1 year ago

... but in any case, from thinking about it briefly, it seems like the CSD case could very well be the best choice in principle. Any spectral modification imposed by the window function will occur on both the X and Y signals, but this should go away during CSD computation because the CSD effectively only looks at relative amplitude and phase (compared to the average-complex-signals case, where any phase shifts across tapers will be combined/averaged rather than removed before averaging).
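A small numerical check of the argument above, under the simplifying assumption that the window imposes the same taper-specific phase shift on both channels' coefficients. Such a shift cancels in each taper's cross-spectrum X_k * conj(Y_k), but it generally changes the relative phase obtained from the taper-averaged complex coefficients. The values are random placeholders, not real multitaper output.

```python
import numpy as np

rng = np.random.default_rng(1)
n_tapers = 4
# One complex coefficient per taper for two channels (hypothetical values).
x = rng.standard_normal(n_tapers) + 1j * rng.standard_normal(n_tapers)
y = rng.standard_normal(n_tapers) + 1j * rng.standard_normal(n_tapers)
# Taper-specific phase shift applied identically to both channels.
shift = np.exp(1j * rng.uniform(0, 2 * np.pi, n_tapers))

# Per-taper CSD: the common shift cancels because shift * conj(shift) == 1.
csd = x * np.conj(y)
csd_shifted = (shift * x) * np.conj(shift * y)

# Averaging complex coefficients first: the shifts are mixed into the average,
# so the relative phase between channels generally changes.
rel_phase = np.angle(x.mean() * np.conj(y.mean()))
rel_phase_shifted = np.angle((shift * x).mean() * np.conj((shift * y).mean()))
```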

And as you point out, this quote makes it clear:

> The multitaper cross-spectral estimator between channel l and m is the average of K direct cross-spectral estimators between the same pair of channels (l and m)

So +1 for using averaged PSD. I'm still curious about the weighted vs unweighted average in the above plots, but even if these are unweighted, I'm guessing the weighted average will not look as reasonable (or be as correct in principle!) as the average CSD case.

drammock commented 1 year ago

@ruuskas can you plot the tapers themselves to see if they look right? I would expect something like this (blue, orange, green are the first 3):

[Screenshot: expected DPSS taper shapes; the first three are shown in blue, orange, and green]
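If plotting the tapers from inside `tfr_array_multitaper` is awkward, a quick stand-in is to compute DPSS tapers directly. The sketch below uses SciPy's `scipy.signal.windows.dpss` rather than the `dpss_windows` function that MNE-Python uses internally, and the window length and time-bandwidth values are arbitrary assumptions. When `Kmax` is given, SciPy normalizes each taper to unit energy, so the rows should be orthonormal.

```python
import numpy as np
from scipy.signal.windows import dpss

n_times, nw, n_tapers = 500, 2.0, 3
# tapers: (n_tapers, n_times); ratios: spectral concentration (eigenvalue) per taper
tapers, ratios = dpss(n_times, nw, Kmax=n_tapers, return_ratios=True)

# Sanity check: the Gram matrix of unit-energy, orthogonal tapers is the identity.
gram = tapers @ tapers.T

# To compare against the expected shapes, e.g. with matplotlib:
#   import matplotlib.pyplot as plt; plt.plot(tapers.T); plt.show()
```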

ruuskas commented 1 year ago

> @ruuskas can you plot the tapers themselves to see if they look right? I would expect something like this (blue, orange, green are the first 3):

Not easily, but that part is inside mne.time_frequency.tfr_array_multitaper, so I would assume it is correct.

ruuskas commented 1 year ago

> In the "average connectivity" case, is it a weighted average? I think it needs to be in order to be comparable. A naive/simple average would unfairly bias the results to the tapers with less power compared to the average-complex and average-CSD cases (which implicitly include the weighting IIUC).

This was just a simple average without any weighting. How should one go about computing the weights? I would assume that the weights should be frequency-specific.

> So +1 for using averaged PSD. I'm still curious about the weighted vs unweighted average in the above plots, but even if these are unweighted, I'm guessing the weighted average will not look as reasonable (or be as correct in principle!) as the average CSD case.

I see. I could add this to the PR where I'm already suggesting several changes to spectral_connectivity_time.

larsoner commented 1 year ago

> I would assume that the weights should be frequency specific.

Usually taper-specific. They are related to the eigenvalues from the eigendecomposition that computes the tapers; see for example:

https://github.com/mne-tools/mne-python/blob/f26528d78764c83f754873c40f17e40d5eb08d2d/mne/time_frequency/multitaper.py#L128

We should already use them in MNE-Python when computing PSDs. In addition to potentially being useful for your comparison above, we should probably use them when averaging the CSDs across tapers just like we do for PSDs.
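A minimal sketch of a weighted taper average of the CSD, assuming weights equal to the square roots of the DPSS eigenvalues (concentration ratios), as referenced in this thread for MNE-Python's PSD computation. The helper `weighted_taper_csd` is hypothetical, the eigenvalues come from SciPy's `dpss` here, and any one-sided scaling factor that MNE-Python applies is left out.

```python
import numpy as np
from scipy.signal.windows import dpss


def weighted_taper_csd(x, y, eigvals):
    """Eigenvalue-weighted average of the cross-spectrum over tapers.

    x, y: complex coefficients shaped (n_tapers, n_freqs).
    eigvals: DPSS concentration ratios, one per taper.
    """
    w = np.sqrt(eigvals)[:, np.newaxis]  # taper-specific weights
    cross = (w * x) * np.conj(w * y)     # weighted per-taper cross-spectra
    return cross.sum(axis=0) / (w ** 2).sum(axis=0)


_, eigvals = dpss(500, 2.0, Kmax=3, return_ratios=True)
rng = np.random.default_rng(0)
x = rng.standard_normal((3, 8)) + 1j * rng.standard_normal((3, 8))
y = rng.standard_normal((3, 8)) + 1j * rng.standard_normal((3, 8))
csd_w = weighted_taper_csd(x, y, eigvals)  # shape (n_freqs,)
csd_u = np.mean(x * np.conj(y), axis=0)    # uniform average, for comparison
```

With equal eigenvalues the weighted and uniform averages coincide, so the weighting only matters to the extent that the eigenvalues of the tapers in use differ from one another.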

ruuskas commented 1 year ago

> Usually taper-specific. They are related to the eigenvalues from the eigen decomp that computes the tapers, see for example
>
> We should already use them in MNE-Python when computing PSDs. In addition to potentially being useful for your comparison above, we should probably use them when averaging the CSDs across tapers just like we do for PSDs.

Just to be sure we're on the same page, the results above were computed using the complex signal output from tfr_array_multitaper, not using the CSD function csd_array_multitaper from MNE-Python. Do you think the weighting should be applied to the output of tfr_array_multitaper?

larsoner commented 1 year ago

I think any time you combine estimates from multiple tapers, you probably want to use a weighted average rather than a naive/uniform average. So assuming your tfr_array_multitaper returns estimates at the level of individual tapers, once you go to combine them (later) at the CSD stage, it seems like you should use the same weights that we use when computing a PSD.

ruuskas commented 1 year ago

In mne.time_frequency.tfr._make_dpss, the tapers are normalized by their norm. I wonder whether this has any significance w.r.t. the weighting with the square roots of the eigenvalues.

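import numpy as np
from mne.time_frequency.multitaper import dpss_windows

# Excerpt from the loop in mne.time_frequency.tfr._make_dpss; t, n_taps,
# time_bandwidth, m (taper index), oscillation and zero_mean are defined
# by the enclosing function.

# Get dpss tapers
tapers, conc = dpss_windows(t.shape[0], time_bandwidth / 2.,
                            n_taps, sym=False)

Wk = oscillation * tapers[m]
if zero_mean:  # to make it zero mean
    real_offset = Wk.mean()
    Wk -= real_offset
# Rescale the complex wavelet so that its norm is sqrt(2)
Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel())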

I will try computing the weights and post the connectivity plots here anyway.

EDIT: Tagging @agramfort here as well, as you seem to be one of the authors of tfr.py.

ruuskas commented 1 year ago

Here is the figure including the weighted average of the CSD. The weights are very close to one, so their effect is minuscule. The code is here.

[Figure: multitaper_taper_averaging, updated to include the weighted-average CSD]

larsoner commented 1 year ago

+1 for using the weighted average, since I think in principle it's more correct and consistent with how we do PSDs.

ruuskas commented 1 year ago

I will add this to #104.