mne-tools / mne-connectivity

Connectivity algorithms that leverage the MNE-Python API.
https://mne.tools/mne-connectivity/dev/index.html
BSD 3-Clause "New" or "Revised" License

[GSOC] implement example of state-space model for connectivity #100

Open jadrew43 opened 2 years ago

jadrew43 commented 2 years ago

PR Description

Google Summer of Code (2022) project

Closes #99

WIP: Linear Dynamic System (state-space model using EM algorithm to find autoregressive coefficients) to infer functional connectivity by interpreting autoregressive coefficients as connectivity strength. The model uses M/EEG data as input, and outputs time-varying autoregressive coefficients for source space labels.
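To make the model description above concrete, here is a minimal simulation sketch of a linear dynamical system with time-varying autoregressive coefficients. It is an illustration only, not the code in this PR: the function name `simulate_lds`, the noise levels, and the random mixing matrix `G` (standing in for a forward-model-derived matrix) are all assumptions. The only detail taken from the PR is the shape of the coefficient array, `(n_labels, n_labels, n_times)`.

```python
import numpy as np


def simulate_lds(A, G, state_noise=0.1, sensor_noise=0.1, seed=0):
    """Simulate x_t = A[:, :, t] @ x_{t-1} + w_t and y_t = G @ x_t + v_t.

    A : (n_labels, n_labels, n_times) time-varying AR coefficients.
    G : (n_sensors, n_labels) hypothetical mixing matrix.
    """
    rng = np.random.default_rng(seed)
    n_labels, _, n_times = A.shape
    n_sensors = G.shape[0]
    x = np.zeros((n_labels, n_times))
    y = np.zeros((n_sensors, n_times))
    x_prev = rng.standard_normal(n_labels)  # random initial state
    for t in range(n_times):
        # Latent (label-level) dynamics with Gaussian state noise
        x[:, t] = A[:, :, t] @ x_prev + state_noise * rng.standard_normal(n_labels)
        # Sensor-level observation with Gaussian sensor noise
        y[:, t] = G @ x[:, t] + sensor_noise * rng.standard_normal(n_sensors)
        x_prev = x[:, t]
    return x, y


n_labels, n_sensors, n_times = 4, 10, 50
# Start from stable self-connections (0.9 on the diagonal) at every time point
A = np.repeat(0.9 * np.eye(n_labels)[:, :, np.newaxis], n_times, axis=2)
# Hypothetical transient off-diagonal coefficient between two labels
A[1, 0, 20:30] = 0.3
G = np.random.default_rng(1).standard_normal((n_sensors, n_labels))
x, y = simulate_lds(A, G)
```

In this sketch the fitting problem would be the reverse direction: given `y` (and `G`), recover `A`. The row/column convention used here for the off-diagonal entries is a choice made for the illustration and may differ from the one used in the PR's plots.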

Completed during GSoC

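import mne

# NOTE: LDS is the model class added in this PR. raw_fname and condition
# are not defined in this snippet; elided arguments are left as `...`.

data_path = mne.datasets.sample.data_path()
sample_folder = data_path / 'MEG' / 'sample'

# Read and crop the raw recording, then find events
raw = mne.io.read_raw_fif(raw_fname).crop(tmax=60)
events = mne.find_events(raw, ...)

# Epoch the data, keeping MEG and EEG channels
event_dict = {...}
epochs = mne.Epochs(raw, events, tmin=-0.2, tmax=0.7, event_id=event_dict,
                    preload=True).pick_types(meg=True, eeg=True)

# Forward solution and noise covariance
fwd_fname = sample_folder / '....'
fwd = mne.read_forward_solution(fwd_fname)

cov_fname = sample_folder / 'sample_audvis-cov.fif'
cov = mne.read_cov(cov_fname)

# Source-space labels (ROIs) between which connectivity is estimated
label_names = ['Aud-lh', 'Aud-rh', 'Vis-lh', 'Vis-rh']
labels = [mne.read_label(sample_folder / 'labels' / f'{label}.label',
                          subject='sample') for label in label_names]

# Fit the model and retrieve the time-varying AR coefficients
model = LDS(lam0=0, lam1=100)
model.add_subject('sample', condition, epochs, labels, fwd, cov)
model.fit(...)
At = model.A
assert At.shape == (len(labels), len(labels), len(epochs.times))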

Check out this link to see my weekly progress. All of the code in this PR is new to MNE-Python's repositories.

Todo after GSoC

Merge checklist

Maintainer, please confirm the following before merging:

cbrnr commented 2 years ago

Hello! Just wanted to say hi to include myself in the loop 😄. Could someone quickly explain what the aim of this PR is? Is it to add VAR-model-based connectivity estimates? State-space sounds like it's implemented as a Kalman filter. At the risk of sounding repetitive, could we look at https://github.com/scot-dev/scot to see if anything could be reused here? I've implemented least-squares VAR estimation (optionally with regularization) to compute several popular (directed) connectivity measures (for a list see https://scot-dev.github.io/scot-doc/api/scot/scot.html#module-scot.connectivity).
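For readers unfamiliar with the least-squares VAR estimation mentioned above, a minimal NumPy sketch of a first-order fit with optional ridge regularization could look like the following. This is not SCoT's API and not the code in this PR; the function name `fit_var1` and the `ridge` parameter are hypothetical and exist only for illustration.

```python
import numpy as np


def fit_var1(x, ridge=0.0):
    """Least-squares estimate of A in x_t ≈ A @ x_{t-1}.

    x : (n_signals, n_times) array.
    ridge : non-negative regularization strength (0 gives ordinary least squares).
    """
    past, future = x[:, :-1], x[:, 1:]
    # Regularized Gram matrix of the lagged signals (symmetric)
    gram = past @ past.T + ridge * np.eye(x.shape[0])
    # A = future @ past.T @ inv(gram); solve for A.T instead of inverting
    return np.linalg.solve(gram, past @ future.T).T


# Simulate a stable two-signal VAR(1) process with known coefficients
rng = np.random.default_rng(0)
A_true = np.array([[0.5, 0.2],
                   [0.0, 0.7]])
n_times = 5000
x = np.zeros((2, n_times))
for t in range(1, n_times):
    x[:, t] = A_true @ x[:, t - 1] + rng.standard_normal(2)

A_hat = fit_var1(x, ridge=1e-3)
```

With enough samples, `A_hat` should be close to `A_true`. Unlike the state-space model in this PR, this estimate is a single coefficient matrix for the whole time series, not a time-varying one.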

jadrew43 commented 2 years ago

Hi @cbrnr, you are correct that this PR aims to implement a Kalman filter using an AR model to measure connectivity. Reviewing SCoT is still on my to-do list; thanks for the reminder.

larsoner commented 2 years ago

From a quick chat with Jordan, here is what we fleshed out a bit based on my suggestion for the public API:

Internal implementation sketch and public API

```
# Internal code
class MEGLDS:
    def __init__(self, ...):
        ...
        self._subject_data = dict()

    def add_subject(self, subject, forward, cov, ...):
        self._subject_data[subject] = dict()
        self._subject_data[subject]['G'] = something(forward, ...)
        self._subject_data[subject]['C'] = something_else(cov, ...)

    def fit(self):
        Gs = np.array([val['G'] for val in self._subject_data.values()])
        ...

# User API should only have:
# - subjects_dir, then per-subject:
#   - subject
#   - Forward
#   - Covariance
#   - Epochs
#   - list of Label
data_path = mne.datasets.sample.data_path()
subjects_dir = data_path / 'subjects'
model = MEGLDS(lambda0, lambda1, ...)
forward = mne.read_forward_solution(...)
cov = mne.read_cov(...)
labels = list()
label_names = ('Aud.lh', 'Aud.rh', 'Visual.lh', 'Visual.rh')
for name in label_names:
    labels.append(mne.read_label(subjects_dir / 'sample' / 'labels' / name))
model.add_subject('sample', subjects_dir=subjects_dir, labels=labels,
                  forward=forward, cov=cov)
...
model.fit()
At = model.A
assert At.shape == (len(labels), len(labels), len(epochs.times))
```

EDIT: I resolved the conversations above where I talk about this since I think it's general enough to discuss in this main thread rather than inline

jadrew43 commented 2 years ago

examples/state_space_connectivity.py is functioning properly on my machine with the output depicted below. I think for a single subject, these results look good. Along the diagonal, values are close to 1, as expected for a computation similar to an autocorrelation. For the condition auditory/left there seems to be a connection from Aud-lh to Aud-rh as seen by the non-zero values in graph [0,1]. I expect measurements to be less noisy when running for a large number of subjects. Please run CI checks.

x-axis: time (seconds); y-axis: connectivity strength (autoregressive coefficients)

jadrew43 commented 2 years ago

I am at a research conference for the next 7 days. When I return I have the following to-do:

Looking forward to your feedback!

jadrew43 commented 2 years ago

I am currently working to incorporate an old dataset to see if these scripts produce the expected results.

jadrew43 commented 2 years ago

Hey @adam2392 one of the CI errors is due to autograd: ModuleNotFoundError: No module named 'autograd'. I'd really like to move to the next step of the proposal and work to integrate jax (to replace autograd) later on in order to keep pace with the summer milestones. Is it alright if I install autograd in the dependencies for mnedev (assuming that will fix the error) for now and work to integrate jax later on?

adam2392 commented 2 years ago

Yeah sure! You can include it as an optional Dev dependency. Do you know where to add? Please also add a comment if you can that we replace it later on.

Perhaps we can start another sep issue to track migration to jax later on?

jadrew43 commented 2 years ago

@adam2392 Cool! I do not know where to add autograd as an optional Dev dependency. Should it be installed in mnedev?

Where should I leave the comment that it can be replaced with jax later on? Just start a new issue on Github with the suggestion?

adam2392 commented 2 years ago

You can add it here: https://github.com/mne-tools/mne-connectivity/blob/main/requirements_testing.txt and then just add a comment e.g.

...
# TODO: we will replace this with Jax
autograd

Then when you import autograd, you should maybe import it within the functions you want to use and not at the top of the Python file. That way there is no error when someone tries to run import mne_connectivity if they don't have autograd/jax installed. For example, this function inside MNE-Python needs pyqt, but imports it within the function so mne still works if the user only has numpy/scipy installed.

> Where should I leave the comment that it can be replaced with jax later on? Just start a new issue on Github with the suggestion?

Yeah I'll create the GH issue.
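The deferred-import pattern described above can be sketched as follows. The helper `require_optional` and the function `gradient_of_sum_of_squares` are hypothetical names chosen for this illustration and are not part of mne-connectivity or MNE-Python; only `importlib.import_module`, `autograd.grad`, and `autograd.numpy` are real APIs.

```python
import importlib


def require_optional(name, purpose):
    """Import an optional dependency on demand, with a helpful error message."""
    try:
        return importlib.import_module(name)
    except ImportError as exc:
        raise ImportError(
            f"The optional dependency {name!r} is required for {purpose}."
        ) from exc


def gradient_of_sum_of_squares(x):
    """Hypothetical function that needs autograd only when it is called."""
    # Imported here rather than at module top, so importing the package
    # itself never fails when autograd is missing.
    autograd = require_optional("autograd", "gradient computation")
    anp = require_optional("autograd.numpy", "gradient computation")
    return autograd.grad(lambda v: anp.sum(v ** 2))(x)
```

Because the import happens inside the function body, a missing `autograd` only raises an error when `gradient_of_sum_of_squares` is actually called, not when the module is imported.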

jadrew43 commented 2 years ago

Ok I think that's complete. Thanks!

jadrew43 commented 1 year ago

Update: I got the code to work for an old dataset; however, the results are not what I expected. I am currently working to get the sample data (much smaller than the old dataset) to work on my original model. This output will be my new ground truth, and I can work iteratively to make sure that each edit I make to conform the original model to the MNE-Python API standards produces the same output. Lesson learned - have a simple ground truth to work with from the beginning :)

jadrew43 commented 1 year ago

Here is the output of my original model using the sample dataset for the auditory/left condition for the 12 ROIs from a different dataset. My next step is to get this same output in the mnedev environment. Then I will work to use only the 4 ROIs commonly used with the sample dataset. Then piece by piece I will recreate the API.

larsoner commented 1 year ago

Excellent!

A good next step is to make sure all random seeds can be set such that if you run this again you get the exact same output (to numerical precision at least). Then you can save the At result to disk now and compare to it each time you replace some piece of code.
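A minimal sketch of that workflow, assuming the fit can be driven by a single seed: `fit_stand_in` below is a hypothetical placeholder for the real model fit (it only draws random numbers), and the file name `At_reference.npy` and the tolerances are arbitrary choices for the illustration.

```python
import tempfile
from pathlib import Path

import numpy as np


def fit_stand_in(seed):
    """Hypothetical stand-in for the model fit; returns an At-shaped array."""
    rng = np.random.default_rng(seed)  # all randomness flows from this seed
    return rng.standard_normal((4, 4, 10))


# Step 1: run once with a fixed seed and save the result as the reference
reference_path = Path(tempfile.mkdtemp()) / "At_reference.npy"
np.save(reference_path, fit_stand_in(seed=0))

# Step 2: after each refactoring step, rerun with the same seed and compare
At_new = fit_stand_in(seed=0)
np.testing.assert_allclose(At_new, np.load(reference_path),
                           rtol=1e-7, atol=1e-12)
```

`np.testing.assert_allclose` raises an `AssertionError` reporting the mismatch if a code change alters the output beyond the tolerances, which makes it suitable as a regression check.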

jadrew43 commented 1 year ago

I have changed the PCA method from being based on the rank of the matrix to a method based on explaining 99% of the variance. This method allows the fitting of the model to run much faster as it produces 147 principal components vs 360 components produced from the rank method. The model output is noticeably but not extremely different. My next step is to perform the processing steps using the 4 labels provided in the sample dataset, which should reduce processing time even further. All processing was completed within the mnedev conda environment.
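To illustrate the difference between the two component-selection rules, here is a minimal NumPy sketch on synthetic data. It is not the PCA code in this PR; the function name `n_components_for_variance` and the synthetic data (3 strong sources plus weak noise in 20 channels) are assumptions made for the example.

```python
import numpy as np


def n_components_for_variance(data, threshold=0.99):
    """Smallest number of principal components explaining `threshold` variance.

    data : (n_channels, n_samples) array.
    """
    centered = data - data.mean(axis=1, keepdims=True)
    # Squared singular values are proportional to per-component variance
    s = np.linalg.svd(centered, compute_uv=False)
    explained = np.cumsum(s ** 2) / np.sum(s ** 2)
    # First index where cumulative explained variance reaches the threshold
    n_keep = int(np.searchsorted(explained, threshold)) + 1
    return min(n_keep, len(s))


rng = np.random.default_rng(0)
n_channels, n_sources, n_samples = 20, 3, 1000
mixing = rng.standard_normal((n_channels, n_sources))
sources = rng.standard_normal((n_sources, n_samples))
# Three strong sources plus weak full-rank noise
data = mixing @ sources + 1e-3 * rng.standard_normal((n_channels, n_samples))

n_by_variance = n_components_for_variance(data, threshold=0.99)
n_by_rank = np.linalg.matrix_rank(data - data.mean(axis=1, keepdims=True))
```

On data like this, the rank-based count keeps every direction that carries any noise at all, while the variance-based count keeps only the directions that carry most of the signal, which is why it yields fewer components.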

jadrew43 commented 1 year ago

Processing completed with 4 labels from sample.

jadrew43 commented 1 year ago

Bootstrapping of epochs, and PCA of epochs.get_data() and forward matrices, are completed in the API. Model fitting is completed in the command-line model. The output from the API (LDS) and the output from the command-line model (MEGLDS) are not identical but are extremely similar.
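For context, one common way to bootstrap epochs is to resample them with replacement and average each resample. The sketch below shows that idea on a plain array shaped like `epochs.get_data()`; it is an assumption about what bootstrapping means here, not the PR's implementation, and `bootstrap_epochs` is a hypothetical name.

```python
import numpy as np


def bootstrap_epochs(data, n_resamples, seed=0):
    """Average epochs over bootstrap resamples drawn with replacement.

    data : (n_epochs, n_channels, n_times) array.
    Returns an array of shape (n_resamples, n_channels, n_times).
    """
    rng = np.random.default_rng(seed)
    n_epochs = data.shape[0]
    # One row of epoch indices (with replacement) per bootstrap resample
    idx = rng.integers(0, n_epochs, size=(n_resamples, n_epochs))
    # data[idx] has shape (n_resamples, n_epochs, n_channels, n_times)
    return data[idx].mean(axis=1)


rng = np.random.default_rng(1)
data = rng.standard_normal((30, 5, 40))  # synthetic stand-in for epoch data
boot = bootstrap_epochs(data, n_resamples=8)
```

Fixing `seed` keeps the resampling reproducible, which matters when comparing two implementations that should produce the same output.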

larsoner commented 1 year ago

Nice!

It would be good to know what the differences are that make them not identical, but really if this version of the API works on our UW data as well (maybe even just for one subject?) then I'd say you could use this as the "ground truth" for correctness of additional changes!

jadrew43 commented 1 year ago

@larsoner Can you look at megssm/mne_util.py L114. Am I using the scaler correctly? Because the results do not agree with the original _scale_sensor_data (L140). Thanks.

I'll get to the CI checks first thing Monday!

adam2392 commented 1 year ago

Hi @jadrew43 and @larsoner any help needed to review code / look at prelim results here? Feel free to lmk where I can help!

larsoner commented 1 year ago

IIRC code is still WIP / needs to be systematically converted to MNE conventions, but some bugs have been found along the way which is good!