int-brain-lab / iblenv

Unified environment and Issue tracker for all IBL
MIT License

[EXTRACT] Flat/zero waveforms in PyKS template extraction #344

Closed chris-langfield closed 2 months ago

chris-langfield commented 9 months ago

Han and Chris are recomputing templates from PyKS units in the ephys atlas pids (~1000 insertions). They report that some templates (10-20%) contain spike waveforms that are all zeros. In some templates the majority of spikes are blank, which could throw off the averaging.

Extraction script: https://github.com/int-brain-lab/ibldevtools/blob/master/chris/spikesorting/2023-09-28_template_extraction.py

Location of data (npz format) on SDSC:

/mnt/home/clangfield/ceph/ephys_atlas_templates

Haansololfp commented 9 months ago

Example insertion (pid = '0c15a331-09ac-445c-837f-6afb5e377e56') with flat waveform contamination:

1) Example units in the insertion: [screenshot]
2) Single waveforms for an example unit: [image]

chris-langfield commented 8 months ago

@Haansololfp states that the zero waveforms span all 384 channels. The following code can be used to find the anomalous waveforms in the saved data:


import numpy as np
import pandas as pd
import torch
from pathlib import Path
from torch.utils.data import Dataset, DataLoader

main_path = '/mnt/home/clangfield/ceph/ephys_atlas_templates'

pids = []
for path in Path(main_path).rglob('*_ks_template_region.npz'):
    # Files are named <pid>_ks_template_region.npz, so the first 36 characters of the name are the pid
    pID = path.name[:36]
    pids.append(pID)
pids = np.array(pids)

LOCAL_DATA_PATH = Path('/mnt/home/hyu10/ceph/analysis').joinpath('latest')
df_channels = pd.read_parquet(LOCAL_DATA_PATH.joinpath('channels.pqt'))
pids_in = np.intersect1d(pids, np.unique(df_channels.reset_index(drop=False)['pid'].values))

LOCAL_DATA_PATH = Path("/mnt/sdceph/users/hyu10/analysis/latest")
df_channels = pd.read_parquet(LOCAL_DATA_PATH.joinpath('channels.pqt'))


class TemplateDataset(Dataset):
    def __init__(self, PIDs, main_path):
        self.pids = PIDs
        self.main_path = main_path

    def __len__(self):
        return len(self.pids)

    def __getitem__(self, idx):
        # Each insertion is stored at <main_path>/<pid>/<pid>_ks_template_region.npz
        pid = self.pids[idx]
        temp_dir = f'{self.main_path}/{pid}/{pid}_ks_template_region.npz'
        data = np.load(temp_dir)
        try:
            return np.array(data['waveforms'])
        except Exception:
            # 'waveforms' could not be read from this file: return an empty array so the caller can skip it
            return np.array([])

template_dataset = TemplateDataset(pids_in, main_path)
template_loader = DataLoader(dataset=template_dataset, batch_size=1)  # , num_workers=5, prefetch_factor=5)

flat_pids = []
zero_frequencies = []
unit_regions = []

for idx, waveforms in enumerate(template_loader):
    pid = pids_in[idx]
    waveforms = np.squeeze(waveforms, axis = 0)

    if waveforms.shape[0]==0:
        continue

    df_channels = pd.read_parquet(LOCAL_DATA_PATH.joinpath('channels.pqt'))

    df_channels = df_channels.reset_index(drop=False)
    df_channels = df_channels[df_channels.pid == pid]
    df_channels = df_channels.reset_index(drop=True)

    CH_regions = np.array(df_channels['acronym'].values)

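    # waveforms has shape (N units, n spikes per unit, T samples, C channels)
    N, n, T, C = np.shape(waveforms)
    # True where a single spike waveform is exactly zero at every sample and on every channel
    all_zero_cases = torch.all(torch.all(waveforms == 0, 2),2)
    # Indices of units that contain at least one all-zero waveform
    all_zero_idx = np.where(torch.any(all_zero_cases, 1))[0]

    if len(all_zero_idx)>0:
        # Per affected unit: fraction of non-NaN spikes that are zero, judged on sample 0 of channel 0
        zero_frequency = torch.divide(torch.sum(waveforms[all_zero_idx,:,0,0] == 0 , 1), torch.sum(~torch.isnan(waveforms[all_zero_idx,:,0,0]) , 1))
        # Peak-to-peak amplitude per channel of the NaN-aware mean waveform
        wf_ptps = torch.max(torch.nanmean(waveforms[all_zero_idx,:,:,:], 1), 1)[0] - torch.min(torch.nanmean(waveforms[all_zero_idx,:,:,:], 1), 1)[0]
        # Channel with the largest peak-to-peak amplitude for each affected unit
        mcs = torch.argmax(wf_ptps, 1)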

        regions = CH_regions[mcs]

        flat_pids.append(pid)
        zero_frequencies.append(zero_frequency)
        unit_regions.append(regions)

chris-langfield commented 6 months ago

pid = '3282a590-8688-44fc-9811-cdf8b80d9a80' and uuid = '617c08fe-8817-4812-bdcd-ecc6a97f6ce5'

Image

chris-langfield commented 6 months ago

The cluster above has only 825 waveforms, but they are saved as a 1000x121x384 array. Possibly the excess NaN slots in the array are what produce the flat waveforms?
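
One way to test this hypothesis is to separate slots that were never filled from slots that are identically zero. The sketch below is illustrative only: it assumes that unused slots are NaN-padded (suggested by the NaN-aware calls in the detection code, but not confirmed here), the function name `classify_waveforms` is hypothetical, and it runs on a small synthetic array rather than the saved npz data.

```python
import numpy as np

def classify_waveforms(wfs):
    """Label each waveform in an (n_spikes, T, C) array.

    Returns three boolean masks of length n_spikes:
    NaN padding, identically zero, and everything else.
    """
    flat = wfs.reshape(wfs.shape[0], -1)
    is_nan_pad = np.all(np.isnan(flat), axis=1)  # slot never filled
    is_zero = np.all(flat == 0, axis=1)          # entire T x C block is 0
    is_valid = ~(is_nan_pad | is_zero)
    return is_nan_pad, is_zero, is_valid

# Synthetic stand-in for one cluster: 10 slots, 7 filled, 2 of those zeroed
rng = np.random.default_rng(0)
wfs = np.full((10, 4, 3), np.nan, dtype=np.float32)
wfs[:7] = rng.normal(size=(7, 4, 3))
wfs[[2, 5]] = 0.0

is_nan_pad, is_zero, is_valid = classify_waveforms(wfs)
print("NaN-padded slots:", int(is_nan_pad.sum()))
print("all-zero waveform indices:", np.flatnonzero(is_zero).tolist())
```

If the zero indices fall inside the range of filled slots, NaN padding alone cannot explain them.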

chris-langfield commented 6 months ago

Actually, even excluding those, the following waveform indices are zero (i.e. the entire 121x384 block is identically 0):

[408, 410, 411, 412, 413, 414, 415, 416, 417, 418, 420]
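
To illustrate why such blocks matter for the template, here is a minimal sketch of how all-zero waveforms bias a NaN-aware mean, and how masking them before averaging avoids it. The helper `masked_template` is a hypothetical name, the data are synthetic, and this is only a workaround sketch, not the extraction change that eventually resolved the issue.

```python
import numpy as np

def masked_template(wfs):
    """Mean waveform (T, C) over spikes, ignoring NaN padding and all-zero blocks."""
    flat = wfs.reshape(wfs.shape[0], -1)
    keep = ~np.all(np.isnan(flat), axis=1) & ~np.all(flat == 0, axis=1)
    if not keep.any():
        # Nothing usable: return an all-NaN template of the right shape
        return np.full(wfs.shape[1:], np.nan)
    return np.nanmean(wfs[keep], axis=0)

# Synthetic cluster: 6 slots of shape (2, 2); true waveform is all ones
wfs = np.ones((6, 2, 2))
wfs[4:] = np.nan   # two unused, NaN-padded slots
wfs[[1, 2]] = 0.0  # two spuriously zeroed waveforms

naive = np.nanmean(wfs, axis=0)  # zeros are averaged in with the real spikes
clean = masked_template(wfs)     # zeros and padding are excluded first
print(naive[0, 0], clean[0, 0])
```

In this toy case half of the filled slots are zero, so the naive template is attenuated to half the true amplitude, while the masked template recovers it.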

k1o0 commented 2 months ago

Now resolved by new waveform extraction.