mne-tools / mne-python

MNE: Magnetoencephalography (MEG) and Electroencephalography (EEG) in Python
https://mne.tools
BSD 3-Clause "New" or "Revised" License

hierarchical tagging faces problems with equalize_event_counts #2521

Closed teonbrooks closed 9 years ago

teonbrooks commented 9 years ago

Hierarchical tagging is currently not compatible with Epochs.equalize_event_counts: passing a partial tag such as 'unprimed' raises a KeyError.

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
/Applications/packages/E-MEG/scripts/priming_sensor_analysis.py in <module>()
     73     epochs.pick_types(meg=True, exclude='bads')
     74 
---> 75     epochs.equalize_event_counts(['unprimed', 'primed'], copy=False)
     76     # plotting grand average
     77     p = epochs.average().plot(show=False)

/Applications/packages/mne-python/mne/epochs.pyc in equalize_event_counts(self, event_ids, method, copy)
   1552             key_match = np.zeros(epochs.events.shape[0])
   1553             for key in eq:
-> 1554                 key_match = np.logical_or(key_match, epochs._key_match(key))
   1555             eq_inds.append(np.where(key_match)[0])
   1556 

/Applications/packages/mne-python/mne/epochs.pyc in _key_match(self, key)
   1297         """Helper function for event dict use"""
   1298         if key not in self.event_id:
-> 1299             raise KeyError('Event "%s" is not in Epochs.' % key)
   1300         return self.events[:, 2] == self.event_id[key]
   1301 

KeyError: 'Event "unprimed" is not in Epochs.'

This is the event_id dictionary:

ipdb> self.event_id
{u'nonword/prime': 9, u'nonword/target': 10, u'word/target/primed': 6, u'fixation': 128, u'word/prime/primed': 5, u'word/prime/unprimed': 1, u'word/target/unprimed': 2}

Here's a print of the epochs:

ipdb> self
<EpochsFIF  |  n_events : 669 (all good), tmin : -0.5 (s), tmax : 1.0 (s), baseline : None,
 u'fixation': 240, u'nonword/prime': 48, u'nonword/target': 94, u'word/prime/primed': 48, u'word/prime/unprimed': 143, u'word/target/primed': 48, u'word/target/unprimed': 48>
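To illustrate why the lookup fails, here is a minimal pure-Python sketch of '/'-separated tag matching applied to the event_id dictionary and counts printed above. The helper `match_tag` is hypothetical (not part of MNE) and assumes a tag matches an event name when every '/'-separated part of the tag is one of the name's parts, the partial-matching rule the rest of this thread relies on.

```python
# event_id and per-condition counts copied from the ipdb output above
event_id = {'nonword/prime': 9, 'nonword/target': 10, 'word/target/primed': 6,
            'fixation': 128, 'word/prime/primed': 5, 'word/prime/unprimed': 1,
            'word/target/unprimed': 2}
counts = {'fixation': 240, 'nonword/prime': 48, 'nonword/target': 94,
          'word/prime/primed': 48, 'word/prime/unprimed': 143,
          'word/target/primed': 48, 'word/target/unprimed': 48}


def match_tag(tag, names):
    """Return the event names whose '/'-separated parts include every part
    of tag. Hypothetical helper for illustration; not part of MNE."""
    parts = tag.split('/')
    return sorted(name for name in names
                  if all(p in name.split('/') for p in parts))


for tag in ['unprimed', 'primed']:
    keys = match_tag(tag, event_id)
    # 'unprimed' is not a key itself, so a plain dict lookup raises KeyError,
    # but as a tag it matches two full event names
    print(tag, keys, sum(counts[k] for k in keys))
```

Under that assumption, 'unprimed' would select 143 + 48 = 191 epochs and 'primed' 48 + 48 = 96. Note that 'primed' does not match 'unprimed', because parts are compared whole rather than as substrings.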
larsoner commented 9 years ago

I'm not sure if this should be a blocker for 0.10. I think of it more as an enhancement, actually, since AFAIK we're still building full support for the hierarchical stuff. Are you okay with removing the tag, @teonlamont?

I think the solution should hopefully be fairly simple: we must have a function somewhere that already looks up event keys in a hierarchy-compatible way, and we need to repurpose that. But we will also need to be extra careful, e.g. that no given event falls into more than one equalization category.
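The extra check mentioned here (no event should fall into more than one equalization category) can be sketched at the level of event names. This is a hypothetical standalone helper using the same assumed tag-matching rule ('/'-separated parts compared whole), not existing MNE code; it reports every event name claimed by at least two of the requested tags.

```python
from itertools import combinations


def match_tag(tag, names):
    """Hypothetical helper: names whose '/'-parts include every part of tag."""
    parts = tag.split('/')
    return {n for n in names if all(p in n.split('/') for p in parts)}


def overlapping_names(tags, names):
    """Return the event names that fall into more than one of the tags."""
    groups = [match_tag(t, names) for t in tags]
    shared = set()
    for a, b in combinations(groups, 2):
        shared |= a & b  # names claimed by both categories
    return shared


names = ['word/prime/primed', 'word/prime/unprimed', 'word/target/primed',
         'word/target/unprimed', 'nonword/prime', 'nonword/target', 'fixation']
print(overlapping_names(['primed', 'unprimed'], names))  # set()
print(sorted(overlapping_names(['primed', 'word'], names)))
# ['word/prime/primed', 'word/target/primed']
```

Checking every pair of groups (rather than only the intersection of all groups) is what catches overlap when more than two categories are requested.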

agramfort commented 9 years ago

I agree. This should be an easy fix. Any takers?

teonbrooks commented 9 years ago

Sure sure, it's not a showstopper. It should be simple; I may have time to do it later today or on Sunday. When are we planning to release v0.10?

agramfort commented 9 years ago

Over the next 2 weeks.

Thanks for looking into it.

jona-sassenhagen commented 9 years ago

The hierarchical tagging code is here btw.

jona-sassenhagen commented 9 years ago

I'm not sure how you would like equalize_event_counts to behave when asked to work with hierarchical tags. First, the two sets of epochs could easily overlap if you use non-orthogonal tags: e.g. if you asked for 'primed' vs. 'word' here, a bunch of trials would show up in both. Second, do you want equalize_event_counts to completely ignore all other trigger information? If so, it would be equivalent to something like:

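import mne

# each selection contains every epoch whose event name includes the tag
epochs_list = [epochs['tag1'], epochs['tag2']]
event_id = {'tag1': 1, 'tag2': 2}
# epochs present in both selections, matched by sample index (column 0)
doubles = (set(epochs_list[0].events[:, 0]) &
           set(epochs_list[1].events[:, 0]))
for name, e in zip(['tag1', 'tag2'], epochs_list):
    # if tags are non-orthogonal, drop the epochs shared by both selections
    e.drop_epochs([idx for idx, t in enumerate(e.events[:, 0]) if t in doubles])
    e.event_id = event_id
    e.events[:, 2] = event_id[name]

new_epochs = mne.epochs.concatenate_epochs(epochs_list)
new_epochs.equalize_event_counts(['tag1', 'tag2'])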
new_epochs.equalize_event_counts()

If you don't want that, you probably also want it to take into account the other triggers when equalizing, which makes for a somewhat more complicated operation and essentially means setting up equalize_event_counts to work with more than two conditions.

jona-sassenhagen commented 9 years ago

Alternatively, a minimal solution would be to add something like this to equalize_event_counts:

ids = self.event_id  # the attribute is event_id, not event_ids
event_ids = [[k for k in ids if e_i in k.split('/')]
             if e_i not in ids and any('/' in k for k in ids)
             else e_i
             for e_i in event_ids]

Plus a check/warning for non-orthogonality, e.g.:

import warnings

if set.intersection(*[set(self[tag].events[:, 0]) for tag in event_ids]):
    warnings.warn('Some epochs match more than one of the requested tags.')

teonbrooks commented 9 years ago

I think dropping the intersection is the right thing to do; those epochs would not contribute any meaningful information to the contrast. Warning about the drop, plus what you suggested, sounds like the right way to do it. Want to try your hand at it?
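The behavior agreed on here (drop the epochs matched by both selections and warn about it) can be sketched directly on an events array, assuming MNE's layout of one row per event with columns (sample, previous value, event code). `drop_overlap` is a hypothetical helper for illustration, not MNE code; because each row carries a single event code, an epoch belongs to both selections exactly when its code maps to a name matching both tags.

```python
import warnings

import numpy as np


def drop_overlap(events, event_id, tag_a, tag_b):
    """Drop rows of an events array matched by both tag_a and tag_b.

    Hypothetical helper. events has MNE's (sample, previous, code) layout;
    event_id maps '/'-separated event names to integer codes.
    """
    def codes_for(tag):
        # codes of every event name that contains all parts of the tag
        parts = tag.split('/')
        return [code for name, code in event_id.items()
                if all(p in name.split('/') for p in parts)]

    in_a = np.isin(events[:, 2], codes_for(tag_a))
    in_b = np.isin(events[:, 2], codes_for(tag_b))
    both = in_a & in_b
    if both.any():
        warnings.warn('%d epochs match both %r and %r and will be dropped.'
                      % (both.sum(), tag_a, tag_b))
    return events[~both]


event_id = {'word/primed': 1, 'word/unprimed': 2, 'nonword': 3}
events = np.array([[100, 0, 1], [200, 0, 2], [300, 0, 3], [400, 0, 1]])
kept = drop_overlap(events, event_id, 'primed', 'word')
print(kept[:, 0])  # [200 300]
```

Comparing event codes gives the same result as the sample-based intersection discussed in this thread, as long as no two events share a sample.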

jona-sassenhagen commented 9 years ago

You do know I'm in Chicago right now, for the SNL? :)

jona-sassenhagen commented 9 years ago

I can try, but I don't have a reliable schedule this week.

jona-sassenhagen commented 9 years ago

Okay, I have an implementation. Current behavior: if you enter hierarchical tags, it internally translates the tag(s) into the list of event_ids they match, prunes duplicates, and then proceeds as if you had called the function with those event_ids directly.

So if event_id.keys() is ['audio/left', 'audio/right', 'visual/left', 'visual/right'],

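ids1 = [["audio/left", "visual/left"], ["audio/right", "visual/right"]]
ids2 = ['left', 'right']
# equalize_event_counts returns (epochs, dropped indices)
eq1, _ = epochs.equalize_event_counts(ids1)
eq2, _ = epochs.equalize_event_counts(ids2)
assert (eq1.events[:, 0] == eq2.events[:, 0]).all()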

However, this does not mean you will necessarily end up with an equal number of trials matching 'left' and matching 'right', respectively.

Okay? @teonlamont, is that what you'd expect?
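The translation step described in this comment can be sketched in plain Python 3. `translate_event_ids` is a hypothetical standalone helper, not MNE code: it accepts a str or a list of str per entry, keeps exact event names unchanged, expands any other tag into the full names it matches, and prunes duplicates within a group. Unlike the implementation posted in the next comment, it expands each tag in a group individually, which is an assumption made for illustration.

```python
def translate_event_ids(event_ids, names):
    """Expand hierarchical tags in event_ids into full event names.

    Hypothetical helper for illustration; names are the event_id keys.
    Each entry of event_ids may be a str or a list of str.
    """
    translated = []
    for entry in event_ids:
        group = [entry] if isinstance(entry, str) else list(entry)
        expanded = []
        for tag in group:
            if tag in names:
                matches = [tag]  # exact event name: keep it unchanged
            else:
                parts = tag.split('/')
                matches = [n for n in names
                           if all(p in n.split('/') for p in parts)]
            for m in matches:
                if m not in expanded:  # prune duplicates, keep first-seen order
                    expanded.append(m)
        translated.append(expanded)
    return translated


names = ['audio/left', 'audio/right', 'visual/left', 'visual/right']
print(translate_event_ids(['left', 'right'], names))
# [['audio/left', 'visual/left'], ['audio/right', 'visual/right']]
```

Because ['left', 'right'] expands to exactly ids1, both equalize_event_counts calls receive the same groups, which is why their surviving events coincide.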

jona-sassenhagen commented 9 years ago

FWIW, the code is:

    def equalize_event_counts(self, event_ids, method='mintime', copy=True):
        """Equalize the number of trials in each condition

        It tries to make the remaining epochs occur as close as possible in
        time. This method works based on the idea that if there happened to be
        some time-varying (like on the scale of minutes) noise characteristics
        during a recording, they could be compensated for (to some extent) in
        the equalization process. This method thus seeks to reduce any of
        those effects by minimizing the differences in the times of the events
        in the two sets of epochs. For example, if one had event times
        [1, 2, 3, 4, 120, 121] and the other one had [3.5, 4.5, 120.5, 121.5],
        it would remove events at times [1, 2] in the first epochs and not
        [120, 121].

        Parameters
        ----------
        event_ids : list
            The event types to equalize. Each entry in the list can either be
            a str (single event) or a list of str. In the case where one of
            the entries is a list of str, event_ids in that list will be
            grouped together before equalizing trial counts across conditions.
            In the case where partial matching (using event_ids with '/') is
            used, processing works as if the event_ids matched by the provided
            tags had been supplied instead.
        method : str
            If 'truncate', events will be truncated from the end of each event
            list. If 'mintime', timing differences between each event list will
            be minimized.
        copy : bool
            If True, a copy of epochs will be returned. Otherwise, the
            function will operate in-place.

        Returns
        -------
        epochs : instance of Epochs
            The modified Epochs instance.
        indices : array of int
            Indices from the original events list that were dropped.

        Notes
        -----
        For example, if epochs.event_id was {'Left': 1, 'Right': 2,
        'Nonspatial': 3}:

            epochs.equalize_event_counts([['Left', 'Right'], 'Nonspatial'])

        would equalize the number of trials in the 'Nonspatial' condition with
        the total number of trials in the 'Left' and 'Right' conditions.
        """
        if copy is True:
            epochs = self.copy()
        else:
            epochs = self
        if len(event_ids) == 0:
            raise ValueError('event_ids must have at least one element')
        if not epochs._bad_dropped:
            epochs.drop_bad_epochs()
        # figure out how to equalize
        eq_inds = list()

        # deal with hierarchical tags
        ids = epochs.event_id
        if "/" in "".join(ids):
            event_ids = [[x] if isinstance(x, string_types) else x
                         for x in event_ids]
            event_ids = [[k for k in ids if all((tag in k.split("/")
                         for tag in id_))]
                         if all(id__ not in ids for id__ in id_)
                         else id_
                         for id_ in event_ids]

            # deal with non-orthogonal tags
            events_ = [set(epochs[x].events[:, 0]) for x in event_ids]
            doubles = events_[0].intersection(events_[1])
            if len(doubles):
                warnings.warn("Warning: the two sets of epochs are "
                              "overlapping. The %s overlapping epochs will"
                              " be dropped." % len(doubles))
                drop_ids = [ii for ii, t in enumerate(epochs.events[:, 0])
                            if t in doubles]
                epochs.drop_epochs(drop_ids, reason='EQUALIZED_COUNT')

        for eq in event_ids:
            eq = np.atleast_1d(eq)
            # eq is now a list of types
            key_match = np.zeros(epochs.events.shape[0])
            for key in eq:
                key_match = np.logical_or(key_match, epochs._key_match(key))
            eq_inds.append(np.where(key_match)[0])

        event_times = [epochs.events[e, 0] for e in eq_inds]
        indices = _get_drop_indices(event_times, method)
        # need to re-index indices
        indices = np.concatenate([e[idx] for e, idx in zip(eq_inds, indices)])
        # actually remove the indices
        epochs.drop_epochs(indices, reason='EQUALIZED_COUNT')
        return epochs, indices
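The 'mintime' strategy described in the docstring above can be illustrated with a small dynamic-programming sketch: given the event times of a longer and a shorter condition, it keeps the order-preserving subset of the longer list that minimizes the summed absolute time differences to the shorter list. This is an assumed formulation for illustration (the helper name `mintime_drop` is hypothetical), not the internals of MNE's `_get_drop_indices`.

```python
import numpy as np


def mintime_drop(long_times, short_times):
    """Pick events of the longer list to keep so that, paired in order with
    the shorter list, the summed absolute time differences are minimal.

    Hypothetical sketch (dynamic programming), not MNE's _get_drop_indices.
    Returns (kept indices, dropped indices) into long_times.
    """
    a = np.asarray(long_times, dtype=float)
    b = np.asarray(short_times, dtype=float)
    n, m = len(a), len(b)
    if n < m:
        raise ValueError('long_times must have at least as many events')
    # cost[i, j]: best cost of pairing b[:j] with a subsequence of a[:i]
    cost = np.full((n + 1, m + 1), np.inf)
    cost[:, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, min(i, m) + 1):
            skip = cost[i - 1, j]  # drop a[i - 1]
            pair = cost[i - 1, j - 1] + abs(a[i - 1] - b[j - 1])
            cost[i, j] = min(skip, pair)
    # backtrack from the full problem to find which events were paired
    keep, i, j = [], n, m
    while j > 0:
        if cost[i, j] == cost[i - 1, j]:
            i -= 1  # a[i - 1] was dropped
        else:
            keep.append(i - 1)
            i, j = i - 1, j - 1
    keep.reverse()
    drop = [k for k in range(n) if k not in keep]
    return keep, drop


keep, drop = mintime_drop([1, 2, 3, 4, 120, 121], [3.5, 4.5, 120.5, 121.5])
print(keep, drop)  # [2, 3, 4, 5] [0, 1]
```

On the docstring's example it keeps the events at times 3, 4, 120 and 121 and drops those at times 1 and 2. The table costs O(n*m) time and memory, which is a design choice of this sketch and is acceptable for typical trial counts.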
jona-sassenhagen commented 9 years ago

Okay I guess it's cool. I'll open a PR unless you think there's something dumb about it.

jona-sassenhagen commented 9 years ago

I'll work on MNE during every slide session about fMRI I think.

teonbrooks commented 9 years ago

@jona-sassenhagen I forgot that you were in Chicago!!! I'm sad I'm not there too with my conference buddy :(

teonbrooks commented 9 years ago

Closed via #2532.