neuromodulation / py_neuromodulation

Real-time analysis of intracranial neurophysiology recordings.
https://neuromodulation.github.io/py_neuromodulation/
MIT License

MNE connectivity 2x speed-up #346

Closed · toni-neurosc closed 3 months ago

toni-neurosc commented 3 months ago

I know you're expecting news on the GUI, but since I also wanted to present the speed optimizations next week, and some people involved in the MNE package might be around the lab, I investigated one of the 3 features that were still weighing us down too much for real-time processing: MNE connectivity (the other 2 being Bispectra and FOOOF).

So I realized that Epochs.__init__() and RawArray.__init__() were taking so much computation time that I wondered whether running the constructors every time was even doing anything. I went through all the lines of code taking notes and found that, no, the 200-300 lines of constructor for each class do nothing for us, at least with our current parameters. So I basically bypassed them by just replacing the NumPy array that the objects were holding.
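
For illustration, a minimal sketch of that idea; it assumes the preloaded data lives in the private _data attribute of the RawArray (a private MNE detail, so this is explicitly a hack that may break across versions), and the variable names and channel setup are hypothetical:

    import numpy as np
    import mne

    # One-time setup: build info and a RawArray once
    info = mne.create_info(ch_names=["ch1", "ch2", "ch3", "ch4"], sfreq=1000, ch_types="eeg")
    raw_array = mne.io.RawArray(np.zeros((4, 1000)), info, verbose=False)

    # Per-window update: swap the data in instead of re-running the constructor
    new_window = np.random.randn(4, 1000)
    raw_array._data = new_window  # bypasses all constructor validation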

I also noticed that spectral_connectivity_epochs was calling another function, add_annotations_to_metadata, that was doing some serious outsourcing to Pandas. Not only was it super-slow, it also made zero changes to the Epochs.annotations property it was manipulating, and that property was never used for anything afterwards. So I checked the conditional that triggers the call to add_annotations_to_metadata and added a little hack that makes the condition fail so it's not called. But it's super-weird that I need to do that to get reasonable behavior from spectral_connectivity_epochs.

Anyway, MNE_connectivity's calc_feature() now takes less than 47% of the time it took before, which lands it in terms of performance right between Sharpwaves and Bursts. So now there's only Bispectra left to check out (FOOOF I have already investigated, and there's no room for optimization there unless I go down some serious linear-algebra rabbit hole that I don't have time for right now, but that is quite exciting tbh).

In fact, it could be even faster with the following change I made to the mean calculation at the end of the function:

        for fband_idx, fband in enumerate(self.fbands):
            # Average over the band's frequency bins for all connections at
            # once, instead of calling np.mean once per connection
            fband_mean = np.mean(dat_conn[:, self.fband_ranges[fband_idx]], axis=1)
            for conn in range(dat_conn.shape[0]):
                key = "_".join(["ch1", self.method, str(conn), fband])
                features_compute[key] = fband_mean[conn]

What that does is compute the mean for all 4 "conn" subsets of data at the same time, so np.mean should be 4 times faster for this particular case. But when I did that, I got a couple of output values that were 0.999999999... instead of 1.0. In principle they should be the same, but after normalization I get a couple of -3 values in the output features CSV instead of 0's. I have no idea what's going on; might be an issue with normalization? I think it deserves a check, because 0.9999999999 vs 1.0 should not make that big a change in the final result.
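
For reference, a batched mean over axis=1 and a per-row mean are not guaranteed to be bit-identical, because the order of the floating-point summation can differ; a small self-contained example of the effect (random data, purely illustrative):

    import numpy as np

    rng = np.random.default_rng(42)
    x = rng.random((4, 1000))

    batched = np.mean(x, axis=1)                          # one vectorized call
    looped = np.array([np.mean(x[i]) for i in range(4)])  # one call per row

    # Any differences are at the level of the last floating-point bits
    print(batched - looped)  # e.g. 0.0 or ~1e-16, never a "large" error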

Also, I have another question about the MNE-connectivity feature: how come the channels are hard-coded in the calc_feature function? I'm talking about this argument to spectral_connectivity_epochs: indices=(np.array([0, 0, 1, 1]), np.array([2, 3, 2, 3])), which according to the docstring is:

    If a bivariate method is called, each array for the seeds
    and targets should contain the channel indices for each bivariate
    connection

So what's up with that? What if the data has fewer channels? And if the seed channels are [0, 0, 1, 1], how come all the output features are keyed as ch1?: key = "_".join(["ch1", self.method, str(conn), fband])

I have the feeling I'm either missing something or the code is.
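
As an illustration of the alternative (hypothetical code, not what calc_feature currently does), the indices could be derived from user-chosen seed and target channels instead of being hard-coded:

    import numpy as np
    from itertools import product

    # Hypothetical user settings instead of hard-coded channel indices
    seed_chs = [0, 1]
    target_chs = [2, 3]

    pairs = list(product(seed_chs, target_chs))  # [(0, 2), (0, 3), (1, 2), (1, 3)]
    indices = (
        np.array([s for s, _ in pairs]),  # seeds   -> [0, 0, 1, 1]
        np.array([t for _, t in pairs]),  # targets -> [2, 3, 2, 3]
    )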

PS: Here are the dev-diary notes I took about the changes I made:

calc_feature:
    - Do the RawArray initialization inside get_epoch_data, to avoid having to call .get_data()
    - Do we need to create_info every time? I would assume the info is always the same
        ○ it is indeed -> create only once
    - Bypass make_fixed_length_events by creating the events ourselves, since we make use of almost none of the parameters except epoch_length (see the sketch below)
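
A minimal sketch of building the events by hand, assuming sfreq and epoch_length (in seconds) are known; MNE events are an (n_events, 3) int array of [sample, previous_event_id, event_id]:

    import numpy as np

    def make_events(n_samples: int, sfreq: float, epoch_length: float) -> np.ndarray:
        # Fixed-length, non-overlapping events covering the data, replacing
        # mne.make_fixed_length_events for this special case
        step = int(sfreq * epoch_length)
        starts = np.arange(0, n_samples - step + 1, step, dtype=int)
        return np.column_stack((starts, np.zeros_like(starts), np.ones_like(starts)))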

Epochs constructor:
    - Why deepcopy info? info is never written to, I think. Threading concerns? In any case, we have created info just before, so… no point.
    - Arguments passed:
                self.raw_array,
                events=events,
                event_id={"rest": 1},
                tmin=0,
                tmax=epoch_length,
                baseline=None,
                reject_by_annotation=True, # Default
                verbose=False,
    - What is it doing?
        ○ BaseRaw check not needed
        ○ no annotations passed, persistent
        ○ proj is false and persistent
        ○ reject_by_annotation is True, but persistent
        ○ sfreq is persistent
        ○ events is NOT PERSISTENT, needs replacement, is passed to BaseEpochs constructor
    - BaseEpochs constructor
        ○ _ensure_events, no need, we're creating them fine
        ○ selected is useless because it will select all anyway
        ○ self.drop_log is useless, we're not dropping anything
        ○ _handle_event_repeated is doing nothing, otherwise it would error
        ○ metadata is None, we're not doing anything with it
        ○ we're not passing detrend
        ○ self._raw is the raw data (line 584), first line that is relevant
        ○ _picks_to_idx basically returns channel indices, seems persistent 
        ○ data is None, these are set:
                    if data is None:
                        self.preload = False
                        self._data = None
                        self._do_baseline = True
        ○ this is set: self._offset = None
        ○ self._raw_times is persistent
        ○ reject_tmin and reject_tmax are None, these parts are skipped
        ○ self.decimate(decim) does nothing
            § calls _set_times, which reassigns self._times_readonly -> but no actual change in times_readonly
        ○ _check_baseline is doing nothing as well since Baseline is None
        ○ More sets and _reject_setup called, but I'm pretty sure it does nothing: 
                self.reject = None
                self.flat = None
        ○ Next comes proj section but it does nothing I think
        ○ preload_at_end is false so load_data is not called
        ○ _check_consistency() is unnecessary
        ○ set_annotations is unnecessary

    - CONCLUSION: Epochs object can be re-used by just setting self._raw = raw (see the sketch after the note below)

NOTE: Epochs would need reinitialization if settings, sfreq or channels change, but I imagine the whole stream would be re-initialized in that case.
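
A minimal sketch of that conclusion, reusing the variables from the earlier RawArray sketch (again relying on private MNE attributes, so this may break across versions):

    import mne

    # One-time setup (per stream): construct the Epochs object once
    epochs = mne.Epochs(
        raw_array,
        events=events,
        event_id={"rest": 1},
        tmin=0,
        tmax=epoch_length,
        baseline=None,
        reject_by_annotation=True,  # default
        verbose=False,
    )

    # Per-window update: point the Epochs at the new data and re-extract
    raw_array._data = new_window
    epochs._raw = raw_array
    data = epochs.get_data()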

spectral_connectivity_epochs()
    - _assemble_spectral_params takes the most time (22%) 

    - the SpectralConnectivity object constructor also takes a lot of time (10%)
    - add_annotations_to_metadata takes very long (11%) but does nothing.
        ○ to prevent the function from being called, annots_in_metadata needs to be True.
        ○ For that, all names in ["annot_onset", "annot_duration", "annot_description"] need to be missing from metadata.columns (see the sketch below)

    - _epoch_spectral_connectivity (6%)
    - _get_n_epochs loops over something, takes 4.5%
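
A sketch of the bypass described in the last notes above; per the follow-up comment below, this amounts to modifying the private Epochs._metadata attr, so it is explicitly a hack:

    import pandas as pd

    # Give the Epochs object a non-None metadata frame that contains none of
    # ["annot_onset", "annot_duration", "annot_description"], so the
    # annots_in_metadata check passes and add_annotations_to_metadata is skipped
    epochs._metadata = pd.DataFrame(index=range(len(epochs.events)))
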
timonmerk commented 3 months ago

Hey @toni-neurosc, thanks for this PR. I think it would be best to discuss the PARRM and mne-connectivity optimization with @tsbinns, since he's a developer of both tools. They were probably not developed for real-time use, so I assume that there is room for speed improvement.

tsbinns commented 3 months ago

I realise this is closed, but just some comments. Can move to a new issue if preferred.


@toni-neurosc I also noticed that spectral_connectivity_epochs was calling another function, add_annotations_to_metadata, that was doing some serious outsourcing to Pandas. Not only was it super-slow, it also made zero changes to the Epochs.annotations property it was manipulating, and that property was never used for anything afterwards. So I checked the conditional that triggers the call to add_annotations_to_metadata and added a little hack that makes the condition fail so it's not called. But it's super-weird that I need to do that to get reasonable behavior from spectral_connectivity_epochs.

add_annotations_to_metadata() is a method of the Epochs class. It does not change the annotations attr, but rather adds this annotation information to the metadata attr. This metadata is then passed through and stored in the returned connectivity object, I imagine as a way of preserving info about the data from which connectivity was derived.

If you have annotations in your Epochs and you have not already added these to metadata, spectral_connectivity_epochs() will do this for you.
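
For context, a minimal example of what that looks like from the user side, using the public method and the column names mentioned in the notes above:

    # Assuming `epochs` is an mne.Epochs object built from annotated raw data
    epochs.add_annotations_to_metadata(overwrite=True)

    # The annotation info now lives in the metadata DataFrame
    print(epochs.metadata[["annot_onset", "annot_duration", "annot_description"]])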

There are two less hacky solutions than modifying the private Epochs._metadata attr:

  1. Set Epochs.annotations=None before computing connectivity (add_annotations_to_metadata() checks this first, and if so, does not make a new DataFrame)
  2. Pass in the epoch data as an array to spectral_connectivity_epochs() and avoid this logic of checking for annotations/metadata (see the sketch below)
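
A minimal sketch of option 2, assuming epochs is an mne.Epochs instance (method="coh" is a placeholder; the indices are the ones from the code discussed above):

    import numpy as np
    from mne_connectivity import spectral_connectivity_epochs

    data = epochs.get_data()  # shape (n_epochs, n_channels, n_times)
    conn = spectral_connectivity_epochs(
        data,  # a plain array, so no annotations/metadata handling is triggered
        sfreq=epochs.info["sfreq"],
        method="coh",
        indices=(np.array([0, 0, 1, 1]), np.array([2, 3, 2, 3])),
        verbose=False,
    )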

@toni-neurosc Also, I have another question about the MNE-connectivity feature: how come the channels are hard-coded in the calc_feature function? I'm talking about this argument to spectral_connectivity_epochs: indices=(np.array([0, 0, 1, 1]), np.array([2, 3, 2, 3]))

Yeah, I would have thought this should also be a parameter the end-user can change, otherwise you are only computing connectivity from this specific combination of channels.


@timonmerk They were probably not developed for real-time use, so I assume that there is room for speed improvement.

Perhaps there are some optimisations that could be made within MNE-Connectivity, but a lot of the overhead comes from things like _assemble_spectral_params, instantiating the connectivity object, and a bunch of other internal checks that are re-run every time you pass in data for a new window, as @toni-neurosc points out. These sorts of things could be omitted for a shorter run-time, but that's not going to come from spectral_connectivity_epochs(); it would require a custom implementation of the pipeline.


When checking the diff I also noticed this:

https://github.com/neuromodulation/py_neuromodulation/blob/dc285675a3556a50fd48e03ee0f9d6fe72c35ddb/py_neuromodulation/nm_mne_connectivity.py#L87-L91

spectral_connectivity_epochs() should work as long as the data has shape (n_epochs x n_channels x n_times), even if n_epochs=1. Is that not the case?
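
For reference, a single real-time window can be given that shape by adding a leading epoch axis (shapes here are hypothetical):

    import numpy as np

    window = np.random.randn(4, 1000)  # (n_channels, n_times), one window
    one_epoch = window[np.newaxis]     # (1, n_channels, n_times), i.e. n_epochs=1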