ahwillia / affinewarp

An implementation of piecewise linear time warping for multi-dimensional time series alignment
MIT License

application to human LFP data #21

Open 003084-K opened 1 year ago

003084-K commented 1 year ago

Hi Dr. Williams,

First of all thank you so much for sharing all of these resources openly online. I've already learned so much just by going through your implementations, and your code is so nicely documented.

I have a motor sequence learning dataset in which I record LFP and ECoG data from patients with movement disorders. They learn two different typed sequences (S1 and S2); each time a fixation cross appears, they type one of the two. I'm hoping to use frequency-domain neural activity during the reaction-time period to predict which sequence the patient is about to type with a simple classifier. The total reaction time is highly variable, so I am thinking of trimming the data to the 200 ms right before movement onset on every trial and then applying time warping within that window. I'd follow that with TCA or some other dimensionality reduction, then feed the resulting factors into feature selection and finally a simple classifier.
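Concretely, the trimming step I have in mind looks something like this (just a sketch; `raw`, `onset_idx`, and `fs` are placeholder names, not my actual variables):

```python
import numpy as np

def trim_pre_movement(raw, onset_idx, fs, window_ms=200):
    """Slice the window_ms before movement onset from each trial.

    raw       : array (n_trials, n_time, n_features)
    onset_idx : movement-onset sample index per trial, shape (n_trials,)
    fs        : sampling rate in Hz
    returns   : array (n_trials, n_window, n_features)
    """
    n_window = int(round(window_ms / 1000.0 * fs))
    trials = []
    for trial, onset in zip(raw, onset_idx):
        start = onset - n_window
        if start < 0:
            raise ValueError("onset too early for requested window")
        trials.append(trial[start:onset])
    # every window has the same length, so trials stack cleanly
    return np.stack(trials)

# toy example: 3 trials, 1000 samples at 1 kHz, 2 features
rng = np.random.default_rng(0)
raw = rng.standard_normal((3, 1000, 2))
onsets = np.array([400, 650, 900])
trimmed = trim_pre_movement(raw, onsets, fs=1000)
print(trimmed.shape)  # (3, 200, 2)
```

Since every trial is cut to the same length, the output stacks directly into the trials x timebins x features array the warping code expects.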

When I apply your TCA code to the (trimmed) raw spectral data just to make sure I'm doing things correctly, the results seem to make sense compared with some of the basic trends in the data (though it doesn't visibly distinguish S1 from S2 in the across-trial factors, unfortunately). S1 is purple dots, S2 is yellow dots.

[Screenshots: TCA results]

When I try to apply the piecewise warping example code, the loss is extremely small, and I'm not sure it makes any sense.

[Screenshots: piecewise warping output and loss]

Similarly, when I try to do a hyperparameter search, the loss is extremely small. In the hyperparameter search there also seems to be no change in loss across iterations, and the results from every random sample draw per fold appear identical (I plotted the loss histories for all hyperparameter samples for all models below; the lines for the same models are just overlapping).

[Screenshot: overlapping loss histories across hyperparameter samples]

Do you have any idea what I am doing wrong? Is it inappropriate to use these functions on spectral neural data? Any suggestions for alternative methods or change to my overall approach?

The code snippet I've been using for piecewise warping is below.

import numpy as np
import matplotlib.pyplot as plt

from affinewarp import ShiftWarping, PiecewiseWarping
from affinewarp.multiwarp import MultiShiftWarping
from affinewarp.crossval import paramsearch
from affinewarp.datasets import piecewise_warped_data

knot_range = (-1, 3)
num_models = 3
n_valid_samples = 10

# g indexes LFP channels
for g in [0, 1, 3]:
    print(g)

    # n_trials x n_time x n_centerfreq spectral data for this channel
    binned = data_og[:, :, :, g].transpose((2, 1, 0))

    # Run the parameter search.
    results = paramsearch(
        binned,           # time series data (trials x timebins x features/units)
        num_models,       # number of hyperparameter settings to randomly sample
        n_valid_samples,  # number of hyperparameter samples per validation set
        n_train_folds=3,  # number of folds used for training
        n_valid_folds=1,  # number of folds used for validation
        n_test_folds=1,   # number of folds used for testing
        knot_range=knot_range,  # range of knots in warping function
        warpreg_range=(1e-3, 1e-2),  # range of warp regularization scale
        smoothness_range=(1e-1, 1e0),  # range of smoothness regularization scale
        iter_range=(50, 51),  # range of optimization iterations
        warp_iter_range=(50, 51),  # range of warp iterations
    )

    # plot hyperparameter search results

    train_rsq = results["train_rsq"]
    valid_rsq = results["valid_rsq"]
    test_rsq = results["test_rsq"]
    knots = results["knots"]

    fig, ax = plt.subplots(1, 1, figsize=(6, 3))

    plt.plot(knots-.1, np.median(train_rsq, axis=1), 'ok', label='train', alpha=.5)
    plt.plot(knots, np.median(valid_rsq, axis=1), 'ob', label='validation', alpha=.7)
    plt.plot(knots+.1, test_rsq, 'or', label='test', alpha=.7)

    ax.set_xticks(range(*knot_range))
    ax.set_xticklabels(['shift', 'linear', 'pwise-1', 'pwise-2'])

    ax.set_ylabel("$R^2$")
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlabel('warping model', labelpad=7)
    ax.legend()
    plt.title("Channel "+str(g))

    fig.tight_layout()

    # plot all loss histories (models x samples x iterations)
    plt.figure()
    n_models_fit, n_samples, _ = results['loss_hists'].shape
    for i in range(n_models_fit):
        for j in range(n_samples):
            plt.plot(results['loss_hists'][i, j, :])

# Fit Models
models = [
    ShiftWarping(smoothness_reg_scale=20.0),
    PiecewiseWarping(n_knots=0, warp_reg_scale=1e-6, smoothness_reg_scale=20.0),
    PiecewiseWarping(n_knots=1, warp_reg_scale=1e-6, smoothness_reg_scale=20.0),
    PiecewiseWarping(n_knots=2, warp_reg_scale=1e-6, smoothness_reg_scale=20.0),
]

for m in models:
    try:
        m.fit(binned, iterations=50, warp_iterations=200)
    except TypeError:
        # fall back for models whose fit() doesn't accept warp_iterations
        m.fit(binned, iterations=50)

# Learning curve.
plt.figure()
for m, label in zip(models, ('shift', 'linear', 'pwise-1', 'pwise-2')):
    plt.plot(m.loss_hist, label=label)
plt.legend()
plt.xlabel('iterations')
plt.ylabel('loss')

# plot example before and after alignment

fig, axes = plt.subplots(5, 5, sharex=True, sharey=True, figsize=(10, 5))

plt.suptitle("Channel: "+str(g))
# precompute each model's aligned data once (channels/features last for imshow)
transformed = [m.transform(binned).transpose(2, 1, 0) for m in models]
for n, axr in enumerate(axes):

    axr[0].imshow(binned.transpose(2, 1, 0)[:, :, n])

    for k, t in enumerate(transformed):
        axr[k + 1].imshow(t[:, :, n])

axes[0, 0].set_title("raw data")

axes[0, 1].set_title("shift-only")
axes[0, 2].set_title("linear")
axes[0, 3].set_title("piecewise-1")
axes[0, 4].set_title("piecewise-2")
ahwillia commented 1 year ago

Thanks for trying the code! We applied time warping to LFP data in the time domain in our paper, but I haven't tried it on spectrograms. It should work in principle, but it's bizarre to me that your loss values are so low to begin with. How are the data normalized? The loss is simply the mean squared error between the trial average and the single-trial activity, so my only thought is that the data are on a very small scale.
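For instance, here is what that loss measures in plain NumPy (a sketch, not the package's internal code path; the function names are made up). If your spectral values live on a tiny numerical scale, z-scoring each feature before fitting should bring the loss into a more interpretable range:

```python
import numpy as np

def template_mse(data):
    """Mean squared error between each trial and the trial-average template.

    data : array (n_trials, n_time, n_features)
    """
    template = data.mean(axis=0)  # (n_time, n_features)
    return np.mean((data - template) ** 2)

def zscore_features(data):
    """Z-score each feature using statistics pooled over trials and time."""
    mu = data.mean(axis=(0, 1), keepdims=True)
    sd = data.std(axis=(0, 1), keepdims=True)
    return (data - mu) / np.maximum(sd, 1e-12)

rng = np.random.default_rng(1)
small = 1e-4 * rng.standard_normal((20, 50, 8))  # data on a tiny scale
print(template_mse(small))                   # tiny loss, ~1e-8
print(template_mse(zscore_features(small)))  # order 1 after normalization
```

So a near-zero loss may just reflect the units of your spectrogram rather than a good fit; comparing the loss before and after normalization is a quick sanity check.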

Are the heatmaps you are showing above single-trial activity? It looks by eye like there is a ton of trial-to-trial variability (i.e. very different patterns appearing on each trial) so perhaps the time warping code is struggling to find a single template that describes everything.
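One quick way to check this (a diagnostic sketch of mine, not something in affinewarp) is to correlate each single trial with the leave-one-out trial average; if those correlations hover near zero, there is little shared template for any warping model to align to:

```python
import numpy as np

def trial_template_corr(data):
    """Pearson correlation of each (flattened) trial with the
    leave-one-out trial average.

    data : array (n_trials, n_time, n_features)
    """
    n = data.shape[0]
    total = data.sum(axis=0)
    corrs = []
    for i in range(n):
        template = (total - data[i]) / (n - 1)  # average excluding trial i
        x, y = data[i].ravel(), template.ravel()
        corrs.append(np.corrcoef(x, y)[0, 1])
    return np.array(corrs)

rng = np.random.default_rng(2)
shared = np.sin(np.linspace(0, 6, 50))[None, :, None]    # common signal
reliable = shared + 0.3 * rng.standard_normal((15, 50, 4))
noise = rng.standard_normal((15, 50, 4))                 # no shared signal

print(trial_template_corr(reliable).mean())  # clearly positive
print(trial_template_corr(noise).mean())     # near zero
```

If your data look more like the `noise` case, warping is unlikely to help and a shared-response assumption may simply not hold in that window.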

I see that you've looked at MultiShiftWarping and this might be the way you need to go. The preprint associated with this is here: https://www.biorxiv.org/content/10.1101/2020.03.02.974014v3

While that preprint sketches out some ideas, I haven't pursued them very deeply in practice. Good luck! :slightly_smiling_face: