Closed by chris-langfield 2 months ago.
Example insertion (pid = '0c15a331-09ac-445c-837f-6afb5e377e56') with flat-waveform contamination:
1) example units in the insertion
2) single waveforms for an example unit
@Haansololfp reports that the zero waveforms are zero across all 384 channels. The following code finds the anomalous waveforms in the saved data:
import numpy as np
import pandas as pd
import torch
from pathlib import Path
from torch.utils.data import Dataset, DataLoader

main_path = '/mnt/home/clangfield/ceph/ephys_atlas_templates'

# Each file lives at <main_path>/<pid>/<pid>_ks_template_region.npz,
# so the pid is the name of the parent directory
pids = [path.parent.name for path in Path(main_path).rglob('*_ks_template_region.npz')]

# Channel table (the original snippet also read it from /mnt/sdceph/users/hyu10/analysis/latest)
LOCAL_DATA_PATH = Path('/mnt/home/hyu10/ceph/analysis').joinpath('latest')
df_channels = pd.read_parquet(LOCAL_DATA_PATH.joinpath('channels.pqt')).reset_index(drop=False)
# Keep only insertions that have both saved templates and channel information
pids_in = np.intersect1d(pids, np.unique(df_channels['pid'].values))


class TemplateDataset(Dataset):
    """Loads the saved spike waveforms of one insertion (pid) per item."""

    def __init__(self, PIDs, main_path):
        self.pids = PIDs
        self.main_path = main_path

    def __len__(self):
        return len(self.pids)

    def __getitem__(self, idx):
        pid = self.pids[idx]
        temp_path = Path(self.main_path).joinpath(pid, f'{pid}_ks_template_region.npz')
        with np.load(temp_path) as data:
            try:
                waveforms = np.array(data['waveforms'])
            except KeyError:
                # file was saved without a 'waveforms' array
                waveforms = np.array([])
        return waveforms


template_dataset = TemplateDataset(pids_in, main_path)
# shuffle=False (the default) keeps batch idx aligned with pids_in
template_loader = DataLoader(dataset=template_dataset, batch_size=1)  # , num_workers=5, prefetch_factor=5)

flat_pids = []
zero_frequencies = []
unit_regions = []
for idx, waveforms in enumerate(template_loader):
    pid = pids_in[idx]
    waveforms = waveforms[0]  # drop the batch dimension added by DataLoader
    if waveforms.shape[0] == 0:
        continue
    # brain-region acronym of each channel of this insertion
    CH_regions = df_channels.loc[df_channels.pid == pid, 'acronym'].to_numpy()
    N, n, T, C = waveforms.shape  # units, spikes, samples, channels
    # (N, n): True where a spike is zero on every sample and every channel
    all_zero_cases = torch.all(torch.all(waveforms == 0, dim=2), dim=2)
    # units with at least one all-zero spike
    all_zero_idx = torch.where(torch.any(all_zero_cases, dim=1))[0]
    if len(all_zero_idx) > 0:
        # fraction of non-NaN (non-padded) spikes that are all zero
        n_valid = torch.sum(~torch.isnan(waveforms[all_zero_idx, :, 0, 0]), 1)
        zero_frequency = torch.sum(all_zero_cases[all_zero_idx], 1) / n_valid
        # peak-to-peak amplitude of the mean waveform on each channel
        mean_wf = torch.nanmean(waveforms[all_zero_idx], 1)
        wf_ptps = torch.max(mean_wf, 1)[0] - torch.min(mean_wf, 1)[0]
        mcs = torch.argmax(wf_ptps, 1)  # max-amplitude channel of each unit
        regions = CH_regions[mcs.numpy()]
        flat_pids.append(pid)
        zero_frequencies.append(zero_frequency)
        unit_regions.append(regions)
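The same detection can be sketched in plain NumPy and checked on a small synthetic array. This is a minimal sketch, assuming unused spike slots are filled entirely with NaN; `flat_spike_fraction` is a hypothetical helper, not part of the code above, and the data are made up.

```python
import numpy as np

def flat_spike_fraction(waveforms):
    """Per-unit fraction of recorded spikes whose whole (T, C) block is zero.

    waveforms: array of shape (N units, n spikes, T samples, C channels).
    Rows that are entirely NaN are treated as padding and not counted.
    """
    all_zero = np.all(waveforms == 0, axis=(2, 3))     # (N, n)
    padded = np.all(np.isnan(waveforms), axis=(2, 3))  # (N, n)
    n_valid = (~padded).sum(axis=1)                    # (N,)
    # units with no valid spikes get 0 instead of a division by zero
    return np.divide(all_zero.sum(axis=1), n_valid,
                     out=np.zeros(len(n_valid)), where=n_valid > 0)

# synthetic example: 2 units, 4 spike slots, 3 samples, 5 channels
rng = np.random.default_rng(0)
wf = rng.normal(size=(2, 4, 3, 5))
wf[0, 1] = 0.0       # unit 0: one flat spike
wf[1, 2:] = np.nan   # unit 1: two padded slots
wf[1, 0] = 0.0       # unit 1: one flat spike among the 2 valid ones
print(flat_spike_fraction(wf))
```

On this array the result is 0.25 for unit 0 (1 flat spike out of 4) and 0.5 for unit 1 (1 flat spike out of 2 non-padded slots).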
Example cluster: pid = '3282a590-8688-44fc-9811-cdf8b80d9a80', uuid = '617c08fe-8817-4812-bdcd-ecc6a97f6ce5'
The cluster above has only 825 spikes but is saved as a 1000 x 121 x 384 array. Could the 175 excess rows (NaN padding?) be producing the flat waveforms?
However, even excluding those rows, the following waveform indices are zero (i.e. the entire 121 x 384 block is identically 0):
[408, 410, 411, 412, 413, 414, 415, 416, 417, 418, 420]
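The padding question can be separated from the genuinely flat spikes by classifying each row of a unit's array. A minimal sketch, assuming padded rows are entirely NaN (not confirmed in this thread); `split_padding_and_flat` is a hypothetical helper, and the array is a small synthetic stand-in for the 1000 x 121 x 384 array of the cluster above.

```python
import numpy as np

def split_padding_and_flat(unit_wf):
    """Classify one unit's spikes, shape (n, T, C), into padding and flat rows.

    Returns the number of non-padded rows and the indices of non-padded
    rows that are identically zero.
    """
    padded = np.all(np.isnan(unit_wf), axis=(1, 2))  # rows that are all NaN
    flat = np.all(unit_wf == 0, axis=(1, 2))         # rows that are all zero
    n_real = int((~padded).sum())
    # NaN == 0 is False, so padded rows are never flagged as flat;
    # the extra mask only makes that explicit
    flat_idx = np.flatnonzero(flat & ~padded)
    return n_real, flat_idx

# synthetic stand-in: 10 spike slots, 4 samples, 6 channels
rng = np.random.default_rng(1)
unit = rng.normal(size=(10, 4, 6))
unit[8:] = np.nan    # last two slots are padding
unit[[3, 5]] = 0.0   # two genuinely flat spikes
n_real, flat_idx = split_padding_and_flat(unit)
print(n_real, flat_idx)
```

If the padding hypothesis holds, `n_real` for the cluster above would be 825, and `flat_idx` would list the zero rows among the real spikes.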
Now resolved by the new waveform extraction.
Han and Chris are recomputing templates from PyKS units in the ephys atlas pids (~1000 insertions). They report that some templates (10-20%) contain spike waveforms that are all zeros. In some templates the majority of spikes are blank, which can bias the averaged template.
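Because a plain mean treats blank spikes as real data, a unit whose spikes are mostly zero yields a template scaled toward zero. A minimal sketch of averaging that drops such spikes first, assuming blank spikes are exactly zero on every sample and channel; `mean_template` is a hypothetical helper, not taken from the extraction script, and the data are synthetic.

```python
import numpy as np

def mean_template(unit_wf, drop_flat=True):
    """Average waveform (T, C) of one unit from spikes of shape (n, T, C).

    NaN-padded rows are ignored by nanmean. If drop_flat is True, spikes
    whose whole (T, C) block is zero are removed before averaging; if no
    spikes remain, nanmean returns all-NaN with a RuntimeWarning.
    """
    if drop_flat:
        flat = np.all(unit_wf == 0, axis=(1, 2))
        unit_wf = unit_wf[~flat]
    return np.nanmean(unit_wf, axis=0)

# synthetic unit: 4 identical spikes plus 6 blank ones
spike = np.array([[0.0, -1.0], [0.0, -4.0], [0.0, -1.0]])  # (T=3, C=2)
unit = np.concatenate([np.repeat(spike[None], 4, axis=0),
                       np.zeros((6, 3, 2))])
naive_ptp = np.ptp(mean_template(unit, drop_flat=False))
clean_ptp = np.ptp(mean_template(unit, drop_flat=True))
print(naive_ptp, clean_ptp)
```

With 6 of 10 spikes blank, the peak-to-peak amplitude of the naive mean is 40% of the true value (1.6 vs 4.0 on this synthetic unit).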
Extraction script: https://github.com/int-brain-lab/ibldevtools/blob/master/chris/spikesorting/2023-09-28_template_extraction.py
Location of data (npz format) on SDSC:
/mnt/home/clangfield/ceph/ephys_atlas_templates