mne-tools / mne-python

MNE: Magnetoencephalography (MEG) and Electroencephalography (EEG) in Python
https://mne.tools
BSD 3-Clause "New" or "Revised" License

Helper function to find source space indices for label vertices #8162

Closed: olafhauk closed this issue 4 years ago

olafhauk commented 4 years ago

I'll need a function (e.g. "_get_source_space_vertices()") that returns, for the FreeSurfer vertex numbers in a label (label.vertices), the corresponding indices into the source space. For example, I'd like to use these indices to select columns of the leadfield (leadfield[:, new_vertices]). The code in _prepare_label_extraction() (in source_estimate.py, copied below for reference) seems to do this already. However, I am not familiar with mixed source spaces and sign flipping, and could only implement and test this for surface source spaces.
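
A minimal NumPy sketch of the mapping described above, assuming a two-hemisphere surface source space whose per-hemisphere vertno arrays are sorted (the same assumption the searchsorted calls in _prepare_label_extraction() make) and a fixed-orientation leadfield with one column per source. The names label_to_src_indices and leadfield_columns are hypothetical, not MNE API. For a free-orientation leadfield, each source index is assumed to expand to three consecutive columns (x, y, z).

```python
import numpy as np


def label_to_src_indices(vertno, label_vertices, hemi):
    """Map FreeSurfer vertex numbers of a label to source-space positions.

    vertno: [lh_vertno, rh_vertno], sorted vertex-number arrays (a stand-in
    for the 'vertno' entries of a surface source space). Positions count
    across both hemispheres, left hemisphere first.
    """
    hemi_idx = {'lh': 0, 'rh': 1}[hemi]
    # Keep only label vertices that are actually in the source space
    used = np.intersect1d(vertno[hemi_idx], label_vertices)
    # Position of each used vertex within this hemisphere's vertno array
    idx = np.searchsorted(vertno[hemi_idx], used)
    # Right-hemisphere sources come after all left-hemisphere sources
    if hemi_idx == 1:
        idx = idx + len(vertno[0])
    return idx


def leadfield_columns(src_idx, n_ori=1):
    """Expand source indices to leadfield column indices.

    n_ori=1 for fixed orientation, 3 for free orientation (assumed column
    order: x, y, z for source 0, then x, y, z for source 1, ...).
    """
    src_idx = np.asarray(src_idx)
    return (n_ori * src_idx[:, None] + np.arange(n_ori)).ravel()


# Toy source space: 4 lh and 3 rh vertices (FreeSurfer vertex numbers)
vertno = [np.array([2, 10, 15, 40]), np.array([5, 7, 30])]
lh_idx = label_to_src_indices(vertno, np.array([1, 10, 40, 41]), 'lh')
rh_idx = label_to_src_indices(vertno, np.array([7, 30]), 'rh')

leadfield = np.arange(2 * 7).reshape(2, 7)  # 2 sensors x 7 sources
sub = leadfield[:, leadfield_columns(rh_idx)]
```

Label vertices that are not in the source space (1 and 41 above) are silently dropped, which is also why an entirely empty label has to be detected separately.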

def _prepare_label_extraction(stc, labels, src, mode, allow_empty, use_sparse):
    """Prepare indices and flips for extract_label_time_course."""
    # If src is a mixed source space, the first 2 source spaces are of surf
    # type and the other ones are of vol type. For a mixed source space,
    # n_labels is given by the number of ROIs of the cortical parcellation
    # plus the number of volume source spaces.
    # (Excerpt from source_estimate.py: np, sparse, _validate_type, _pl,
    # logger, warn and _temporary_vertices are module-level names there.)
    from .label import label_sign_flip, Label, BiHemiLabel

    # get vertices from source space, they have to be the same as in the stcs
    vertno = stc.vertices
    nvert = [len(vn) for vn in vertno]

    # do the initialization
    label_vertidx = list()
    label_flip = list()
    for s, v, hemi in zip(src, stc.vertices, ('left', 'right')):
        n_missing = (~np.isin(v, s['vertno'])).sum()
        if n_missing:
            raise ValueError('%d/%d %s hemisphere stc vertices missing from '
                             'the source space, likely mismatch'
                             % (n_missing, len(v), hemi))
    bad_labels = list()
    for li, label in enumerate(labels):
        if use_sparse:
            assert isinstance(label, dict)
            vertidx = label['csr']
            # This can happen if some labels aren't present in the space
            if vertidx.shape[0] == 0:
                bad_labels.append(label['name'])
                vertidx = None
            # Efficiency shortcut: use linearity early to avoid redundant
            # calculations
            elif mode == 'mean':
                vertidx = sparse.csr_matrix(vertidx.mean(axis=0))
            label_vertidx.append(vertidx)
            label_flip.append(None)
            continue
        # standard case
        _validate_type(label, (Label, BiHemiLabel), 'labels[%d]' % (li,))

        if label.hemi == 'both':
            # handle BiHemiLabel
            sub_labels = [label.lh, label.rh]
        else:
            sub_labels = [label]
        this_vertidx = list()
        for slabel in sub_labels:
            if slabel.hemi == 'lh':
                this_vertices = np.intersect1d(vertno[0], slabel.vertices)
                vertidx = np.searchsorted(vertno[0], this_vertices)
            elif slabel.hemi == 'rh':
                this_vertices = np.intersect1d(vertno[1], slabel.vertices)
                vertidx = nvert[0] + np.searchsorted(vertno[1], this_vertices)
            else:
                raise ValueError('label %s has invalid hemi' % label.name)
            this_vertidx.append(vertidx)

        # convert it to an array
        this_vertidx = np.concatenate(this_vertidx)
        this_flip = None
        if len(this_vertidx) == 0:
            bad_labels.append(label.name)
            this_vertidx = None  # to later check if label is empty
        elif mode not in ('mean', 'max'):  # mode-dependent initialization
            # label_sign_flip uses two properties:
            #
            # - src[ii]['nn']
            # - src[ii]['vertno']
            #
            # So if we override vertno with the stc vertices, it will pick
            # the correct normals.
            with _temporary_vertices(src, stc.vertices):
                this_flip = label_sign_flip(label, src[:2])[:, None]

        label_vertidx.append(this_vertidx)
        label_flip.append(this_flip)

    if len(bad_labels):
        msg = ('source space does not contain any vertices for %d label%s:\n%s'
               % (len(bad_labels), _pl(bad_labels), bad_labels))
        if not allow_empty:
            raise ValueError(msg)
        else:
            msg += '\nAssigning all-zero time series.'
            if allow_empty == 'ignore':
                logger.info(msg)
            else:
                warn(msg)

    return label_vertidx, label_flip

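
The sign flip computed by label_sign_flip depends on the source normals (src[ii]['nn']) of the label's vertices. A minimal NumPy sketch of the principle, assuming the normals are already selected into an (n_vertices, 3) array: find the dominant direction of the normals with an SVD and flip every source whose normal points against it. This illustrates the idea only and is not MNE's exact implementation; sign_flip_from_normals is a hypothetical name.

```python
import numpy as np


def sign_flip_from_normals(normals):
    """Return +1/-1 per source so that the flipped normals roughly agree.

    normals: (n_vertices, 3) array of source normals for one label.
    """
    normals = np.asarray(normals, dtype=float)
    if len(normals) == 0:
        return np.array([], dtype=int)
    # First right singular vector = dominant orientation of the normals
    _, _, vh = np.linalg.svd(normals, full_matrices=False)
    dots = normals @ vh[0]
    # The SVD sign is arbitrary: pick it so most normals project positively
    if dots.mean() < 0:
        dots = -dots
    return np.sign(dots).astype(int)


# Two normals point "up", one points "down" (e.g. across a sulcus)
normals = np.array([[0.0, 0.0, 1.0],
                    [0.0, 0.1, 0.99],
                    [0.0, 0.0, -1.0]])
flip = sign_flip_from_normals(normals)
```

Multiplying each source time course by its flip value before averaging keeps opposite-facing sources from cancelling; the [:, None] in the code above reshapes the flips into a column so they broadcast over time samples.
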
olafhauk commented 4 years ago

I'll start with something for my own purposes. But I'm not brave enough yet to also change _prepare_label_extraction.

olafhauk commented 4 years ago

_prepare_label_extraction() actually does what I want (it returns label_vertidx). I'll simply use that.

olafhauk commented 4 years ago

Can I add a helper function _get_source_space_vertices(stc, labels, src, allow_empty) that calls

label_vertidx, _ = _prepare_label_extraction(stc, labels, src, mode=None, allow_empty=allow_empty, use_sparse=False)

and only returns label_vertidx?

I don't think mode and use_sparse are relevant to me.
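
A minimal standalone sketch of what such an indices-only helper could look like, assuming surface-only labels given as simple (name, hemi, vertices) tuples instead of MNE Label objects, and a vertno list [lh_vertno, rh_vertno] of sorted arrays. get_label_vertidx is a hypothetical name; its allow_empty handling imitates the bad_labels logic of _prepare_label_extraction() (raise, warn, or only log when a label has no vertices in the source space), using the standard logging and warnings modules in place of MNE's logger and warn.

```python
import logging
import warnings

import numpy as np

logger = logging.getLogger(__name__)


def get_label_vertidx(vertno, labels, allow_empty=False):
    """Return one array of source-space indices per label (None if empty).

    vertno: [lh_vertno, rh_vertno], sorted FreeSurfer vertex numbers.
    labels: list of (name, hemi, vertices) tuples, hemi in {'lh', 'rh'}.
    """
    # Right-hemisphere indices are offset by the number of lh sources
    offsets = {'lh': 0, 'rh': len(vertno[0])}
    label_vertidx, bad_labels = [], []
    for name, hemi, vertices in labels:
        if hemi not in offsets:
            raise ValueError('label %s has invalid hemi' % name)
        hemi_vertno = vertno[0 if hemi == 'lh' else 1]
        used = np.intersect1d(hemi_vertno, vertices)
        idx = offsets[hemi] + np.searchsorted(hemi_vertno, used)
        if len(idx) == 0:
            bad_labels.append(name)
            idx = None  # marks an empty label, as in the original code
        label_vertidx.append(idx)
    if bad_labels:
        msg = ('source space does not contain any vertices for %d label(s):'
               '\n%s' % (len(bad_labels), bad_labels))
        if not allow_empty:
            raise ValueError(msg)
        elif allow_empty == 'ignore':
            logger.info(msg)
        else:
            warnings.warn(msg)
    return label_vertidx


vertno = [np.array([2, 10, 15, 40]), np.array([5, 7, 30])]
labels = [('a-lh', 'lh', [10, 15]), ('b-rh', 'rh', [30, 31]),
          ('empty-lh', 'lh', [99])]
vertidx = get_label_vertidx(vertno, labels, allow_empty='ignore')
```

A BiHemiLabel would be handled by computing the lh and rh parts separately and concatenating them, as the loop over sub_labels does in the original code.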

larsoner commented 4 years ago

Why not just put that one line where you would otherwise put your helper call?

olafhauk commented 4 years ago

Ok, will do. I thought it would be clearer and might be used more often in the future.

olafhauk commented 4 years ago

Just one more (hopefully last) thing: can I allow stc=None and mode=None in _prepare_label_extraction()? stc and mode are needed for extracting activation time courses, but not for getting just the indices.

I'm thinking about changes like this (see CHANGE in three places):


def _prepare_label_extraction(stc, labels, src, mode, allow_empty, use_sparse):
    """Prepare indices and flips for extract_label_time_course."""
    # If src is a mixed source space, the first 2 source spaces are of surf
    # type and the other ones are of vol type. For a mixed source space,
    # n_labels is given by the number of ROIs of the cortical parcellation
    # plus the number of volume source spaces.
    # CHANGE
    # If mode=None and stc=None (i.e., no activation time courses are
    # needed), only compute the vertex indices.
    from .label import label_sign_flip, Label, BiHemiLabel

    # get vertices from source space, they have to be the same as in the stcs
    # CHANGE
    if stc is not None:
        vertno = stc.vertices

        for s, v, hemi in zip(src, stc.vertices, ('left', 'right')):
            n_missing = (~np.isin(v, s['vertno'])).sum()
            if n_missing:
                raise ValueError('%d/%d %s hemisphere stc vertices missing from '
                                 'the source space, likely mismatch'
                                 % (n_missing, len(v), hemi))
    else:
        # src is a list of per-source-space dicts, so collect each 'vertno'
        vertno = [s['vertno'] for s in src]

    nvert = [len(vn) for vn in vertno]

    # do the initialization
    label_flip = list()
    label_vertidx = list()

    bad_labels = list()
    for li, label in enumerate(labels):
        if use_sparse:
            assert isinstance(label, dict)
            vertidx = label['csr']
            # This can happen if some labels aren't present in the space
            if vertidx.shape[0] == 0:
                bad_labels.append(label['name'])
                vertidx = None
            # Efficiency shortcut: use linearity early to avoid redundant
            # calculations
            elif mode == 'mean':
                vertidx = sparse.csr_matrix(vertidx.mean(axis=0))
            label_vertidx.append(vertidx)
            label_flip.append(None)
            continue
        # standard case
        _validate_type(label, (Label, BiHemiLabel), 'labels[%d]' % (li,))

        if label.hemi == 'both':
            # handle BiHemiLabel
            sub_labels = [label.lh, label.rh]
        else:
            sub_labels = [label]
        this_vertidx = list()
        for slabel in sub_labels:
            if slabel.hemi == 'lh':
                this_vertices = np.intersect1d(vertno[0], slabel.vertices)
                vertidx = np.searchsorted(vertno[0], this_vertices)
            elif slabel.hemi == 'rh':
                this_vertices = np.intersect1d(vertno[1], slabel.vertices)
                vertidx = nvert[0] + np.searchsorted(vertno[1], this_vertices)
            else:
                raise ValueError('label %s has invalid hemi' % label.name)
            this_vertidx.append(vertidx)

        # convert it to an array
        this_vertidx = np.concatenate(this_vertidx)
        this_flip = None
        if len(this_vertidx) == 0:
            bad_labels.append(label.name)
            this_vertidx = None  # to later check if label is empty
        # CHANGE
        elif mode is None:
            this_flip = []
        elif mode not in ('mean', 'max'):  # mode-dependent initialization
            # label_sign_flip uses two properties:
            #
            # - src[ii]['nn']
            # - src[ii]['vertno']
            #
            # So if we override vertno with the stc vertices, it will pick
            # the correct normals.
            with _temporary_vertices(src, vertno):
                this_flip = label_sign_flip(label, src[:2])[:, None]

        label_vertidx.append(this_vertidx)
        label_flip.append(this_flip)

    if len(bad_labels):
        msg = ('source space does not contain any vertices for %d label%s:\n%s'
               % (len(bad_labels), _pl(bad_labels), bad_labels))
        if not allow_empty:
            raise ValueError(msg)
        else:
            msg += '\nAssigning all-zero time series.'
            if allow_empty == 'ignore':
                logger.info(msg)
            else:
                warn(msg)

    return label_vertidx, label_flip
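
The stc=None branch relies on a source space being a list of per-hemisphere dicts, so its vertex numbers are [s['vertno'] for s in src]. A small sketch of both branches, with plain dicts standing in for MNE's SourceSpaces and a SimpleNamespace standing in for an stc; resolve_vertno and the mock objects are hypothetical. It also exercises the mismatch check that guards the stc branch.

```python
from types import SimpleNamespace

import numpy as np


def resolve_vertno(src, stc=None):
    """Return vertex numbers from stc if given, else from the source space."""
    if stc is None:
        # No stc: use every vertex in each source space
        return [np.asarray(s['vertno']) for s in src]
    # With an stc: its vertices must be a subset of the source space
    for s, v, hemi in zip(src, stc.vertices, ('left', 'right')):
        n_missing = (~np.isin(v, s['vertno'])).sum()
        if n_missing:
            raise ValueError('%d/%d %s hemisphere stc vertices missing from '
                             'the source space, likely mismatch'
                             % (n_missing, len(v), hemi))
    return stc.vertices


src = [{'vertno': np.array([2, 10, 15, 40])},
       {'vertno': np.array([5, 7, 30])}]
stc = SimpleNamespace(vertices=[np.array([10, 40]), np.array([7])])
bad_stc = SimpleNamespace(vertices=[np.array([10, 41]), np.array([7])])

vertno_src = resolve_vertno(src)
vertno_stc = resolve_vertno(src, stc)
nvert = [len(vn) for vn in vertno_src]
```

Either way, nvert[0] then provides the offset for right-hemisphere indices, so the rest of the function can stay unchanged.
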

agramfort commented 4 years ago

Open a PR; it will be easier to discuss the code details there.