NeuralAnalysis / PyalData

Repository for the Python implementation of the TrialData analysis library.
GNU General Public License v3.0

`restrict_to_interval` has weird behavior when `epoch_fun` output depends on the arrays to be trimmed #134

Closed. raeedcho closed this issue 1 year ago.

raeedcho commented 2 years ago

When I define an epoch_fun that trims off the NaNs from a PyalData DataFrame, the trimmed columns of a trial can end up with different array lengths. This is because epoch_fun is applied to one time-varying column at a time, and each trimmed column is written back into trial_data before epoch_fun is evaluated for the next column. Later calls to epoch_fun therefore see signals that have already been trimmed and can return a different slice.

The offending snippet is lines 77-79 in interval.py:

    # cut time varying signals
    for col in time_fields:
        trial_data[col] = extract_interval_from_signal(trial_data, col, epoch_fun)
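To see why the columns drift apart, here is a stdlib-only sketch that mimics the loop above. It assumes a simplified data model (each trial is a plain dict mapping column names to lists, not a row of a PyalData DataFrame) and a hypothetical epoch_fun that keeps the NaN-free span of a single `pos` column; `restrict_sequential` and the column names are illustrative only, not part of PyalData.

```python
import math

def epoch_fun(trial):
    # Hypothetical epoch_fun: keep the span where the "pos" signal is not NaN.
    viable = [i for i, x in enumerate(trial["pos"]) if not math.isnan(x)]
    return slice(viable[0], viable[-1] + 1)

def restrict_sequential(trial_data, time_fields, epoch_fun):
    # Mimics the loop in interval.py: each column is trimmed across all trials
    # and written back before epoch_fun is evaluated for the next column.
    for col in time_fields:
        trimmed = [trial[col][epoch_fun(trial)] for trial in trial_data]
        for trial, signal in zip(trial_data, trimmed):
            trial[col] = signal
    return trial_data

nan = math.nan
trials = [{"pos": [nan, nan, 1.0, 2.0, 3.0, nan],
           "vel": [10.0, 11.0, 12.0, 13.0, 14.0, 15.0]}]
restrict_sequential(trials, ["pos", "vel"], epoch_fun)
print(trials[0]["pos"])  # [1.0, 2.0, 3.0]
print(trials[0]["vel"])  # [10.0, 11.0, 12.0], not the intended [12.0, 13.0, 14.0]
```

With this single reference signal the lengths happen to match, but `vel` is cut with the slice computed from the already-trimmed `pos`, so it keeps the wrong samples. With other epoch_fun definitions or column orders, the same mechanism can also produce mismatched lengths.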

Example usage that causes issue:

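    import numpy as np
    import pyaldata

    # ref_signals: names of the time-varying columns to check for NaNs.
    # ref_signals and trial_data are assumed to be defined earlier.
    def epoch_fun(trial):
        signals = np.column_stack([trial[sig] for sig in ref_signals])
        nan_times = np.any(np.isnan(signals), axis=1)
        viable_times = np.nonzero(~nan_times)[0]
        first_viable_time = viable_times[0]
        last_viable_time = viable_times[-1]
        return slice(first_viable_time, last_viable_time + 1)

    td_trimmed = pyaldata.restrict_to_interval(trial_data, epoch_fun=epoch_fun)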

Proposed solution: at the very least, there should probably be a warning explaining how to define epoch_fun so that it avoids this behavior. Ideally, though, restrict_to_interval would save each intermediate result into a copy of trial_data instead of overwriting the original at every step, so that epoch_fun always sees the untrimmed signals.
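One way to define epoch_fun so that it avoids this behavior, sketched here as an assumption rather than anything PyalData documents, is to compute the NaN-free window once, before restricting, and store it in scalar fields that the trimming loop never touches. The sketch reuses the simplified dict-of-lists model; `add_viable_bounds`, `restrict_sequential`, and the field names `idx_first_viable` / `idx_last_viable` are hypothetical.

```python
import math

def restrict_sequential(trial_data, time_fields, epoch_fun):
    # Same sequential trim-and-write-back loop as in interval.py.
    for col in time_fields:
        trimmed = [trial[col][epoch_fun(trial)] for trial in trial_data]
        for trial, signal in zip(trial_data, trimmed):
            trial[col] = signal
    return trial_data

def add_viable_bounds(trial_data, ref_signals):
    # Precompute the NaN-free window from the untrimmed signals (assumed to
    # share one length) and store it in scalar, non-time-varying fields.
    for trial in trial_data:
        n = len(trial[ref_signals[0]])
        viable = [i for i in range(n)
                  if not any(math.isnan(trial[sig][i]) for sig in ref_signals)]
        trial["idx_first_viable"] = viable[0]
        trial["idx_last_viable"] = viable[-1]
    return trial_data

def epoch_fun(trial):
    # Reads only the precomputed scalars, so it returns the same slice
    # no matter which columns have already been trimmed.
    return slice(trial["idx_first_viable"], trial["idx_last_viable"] + 1)

nan = math.nan
trials = [{"pos": [nan, nan, 1.0, 2.0, 3.0, nan],
           "vel": [10.0, 11.0, 12.0, 13.0, 14.0, 15.0]}]
add_viable_bounds(trials, ["pos", "vel"])
restrict_sequential(trials, ["pos", "vel"], epoch_fun)
print(trials[0]["vel"])  # [12.0, 13.0, 14.0]
```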

bagibence commented 2 years ago

Thanks! I see what you mean.

As a solution, how about just getting rid of extract_interval_from_signal, and instead applying epoch_fun and indexing directly within restrict_to_interval per trial? (I don't see any other mention of extract_interval_from_signal, so it shouldn't break anything else.)
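A minimal sketch of that per-trial idea, again on a simplified model where each trial is a plain dict of lists rather than a DataFrame row: epoch_fun is evaluated once per trial, on the untrimmed data, and that single slice is applied to every time-varying field. `restrict_per_trial` and the column names are hypothetical; mapping this onto the DataFrame inside restrict_to_interval is left out.

```python
import math

def restrict_per_trial(trial_data, time_fields, epoch_fun):
    # Evaluate epoch_fun once per trial, on the untrimmed trial, and apply
    # that one slice to every time-varying field. Builds new dicts, so the
    # input trials are left unchanged.
    restricted = []
    for trial in trial_data:
        epoch = epoch_fun(trial)
        new_trial = dict(trial)
        for col in time_fields:
            new_trial[col] = trial[col][epoch]
        restricted.append(new_trial)
    return restricted

def epoch_fun(trial):
    # Hypothetical epoch_fun: keep the span where "pos" is not NaN.
    viable = [i for i, x in enumerate(trial["pos"]) if not math.isnan(x)]
    return slice(viable[0], viable[-1] + 1)

nan = math.nan
trials = [{"pos": [nan, nan, 1.0, 2.0, 3.0, nan],
           "vel": [10.0, 11.0, 12.0, 13.0, 14.0, 15.0]}]
out = restrict_per_trial(trials, ["pos", "vel"], epoch_fun)
print(out[0]["pos"], out[0]["vel"])  # [1.0, 2.0, 3.0] [12.0, 13.0, 14.0]
```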

raeedcho commented 2 years ago

Actually, here's the solution I implemented; it seems to work fine.

Replace lines 77-79 in interval.py with the following:

    # cut time varying signals
    trim_temp = {
        col: extract_interval_from_signal(trial_data, col, epoch_fun)
        for col in time_fields
    }
    trial_data = trial_data.assign(**trim_temp)

It first trims every time-varying signal from the unmodified trial_data, storing the results in a temporary Python dict of lists (one entry per column), and only then assigns them all to trial_data at once. I just opened a pull request with this change, so if it works for your data, this might be the simplest fix. That said, it doesn't rule out getting rid of extract_interval_from_signal, so we could do that too.
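The collect-then-assign pattern can be sketched without pandas, assuming each trial is a plain dict of lists; `extract_interval` is a hypothetical stand-in for extract_interval_from_signal, and the explicit write-back loop plays the role of `trial_data.assign(**trim_temp)`.

```python
import math

def extract_interval(trial_data, col, epoch_fun):
    # Hypothetical stand-in for extract_interval_from_signal:
    # returns one trimmed copy of column `col` per trial.
    return [trial[col][epoch_fun(trial)] for trial in trial_data]

def restrict_collect_then_assign(trial_data, time_fields, epoch_fun):
    # Step 1: trim every column while trial_data is still untouched,
    # so epoch_fun always sees the original signals.
    trim_temp = {col: extract_interval(trial_data, col, epoch_fun)
                 for col in time_fields}
    # Step 2: write the trimmed columns back only after every slice is computed.
    for col, signals in trim_temp.items():
        for trial, signal in zip(trial_data, signals):
            trial[col] = signal
    return trial_data

def epoch_fun(trial):
    # Hypothetical epoch_fun: keep the span where "pos" is not NaN.
    viable = [i for i, x in enumerate(trial["pos"]) if not math.isnan(x)]
    return slice(viable[0], viable[-1] + 1)

nan = math.nan
trials = [{"pos": [nan, nan, 1.0, 2.0, 3.0, nan],
           "vel": [10.0, 11.0, 12.0, 13.0, 14.0, 15.0]}]
restrict_collect_then_assign(trials, ["pos", "vel"], epoch_fun)
print(trials[0]["pos"], trials[0]["vel"])  # [1.0, 2.0, 3.0] [12.0, 13.0, 14.0]
```

In this pattern epoch_fun is still called once per column for each trial, which is harmless for a cheap epoch_fun; bagibence's per-trial suggestion would evaluate it only once per trial.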

bagibence commented 1 year ago

Yes, looks great, that should work as well. Thanks a lot!

bagibence commented 1 year ago

Fixed by @raeedcho in #135