SpikeInterface / spikeinterface

A Python-based module for creating flexible and robust spike sorting pipelines.
https://spikeinterface.readthedocs.io
MIT License

Sorting analyzer behavior with multi-probe recording #3383

Open RobertoDF opened 2 months ago

RobertoDF commented 2 months ago

Hi, I have a problem with the following pipeline: I have a recording with two NPX1 probes. I sort each probe individually, curate each one in phy, load each sorting with read_phy, and aggregate them with aggregate_units. I then compute extensions on this aggregated sorting.

The waveforms from this aggregated sorting analyzer are, however, off. If I plot them using plot_unit_waveforms, I see four columns of waveforms: two columns are fine, two columns are flat lines (I'll post a picture later). Notably, if I plot using a sorting_analyzer created from an individual sorting (not the aggregated one), the waveforms are fine. I guess the waveform extraction is not taking the probe group of each channel into account and is picking the channels corresponding to the peak template on both probes? Not sure though.

Should I use the sorting analyzer only on non-aggregated sorting objects? If yes, shouldn't there be a method corresponding to aggregate_units for sorting analyzers too?

Thanks!

zm711 commented 2 months ago

I think the picture would help. I also sort with probe groups and I haven't noticed this problem before. The sparsity should take care of this, I believe. Are you creating your analyzer with sparse=False to get dense info?
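
A minimal sketch of the two options, using the variable names from the snippet posted further down (sorting, raw_rec); this is just an illustration of the parameter, not code from this issue:

sparse_analyzer = create_sorting_analyzer(sorting, raw_rec, sparse=True)   # default: per-unit channel neighborhood
dense_analyzer = create_sorting_analyzer(sorting, raw_rec, sparse=False)   # dense: all channels for every unit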

RobertoDF commented 2 months ago

They all look like this:

[Screenshots: plotted unit waveforms, with two columns of flat traces]
RobertoDF commented 2 months ago

This is the code I use; I don't explicitly set sparsity:

sortings = []
for probe_n in tqdm(range(len(raw_rec.get_probes()))):
    sortings.append(read_phy(Path(f"{path_recording_folder}/spike_interface_output/probe{probe_n}/sorter_output")))

sorting = aggregate_units(sortings)

analyzer = create_sorting_analyzer(
    sorting, raw_rec, sparse=True, format="binary_folder",
    folder=f"{path_recording_folder}/spike_interface_output/sorting_analyzer_post_curation",
    overwrite=True, **job_kwargs)
zm711 commented 2 months ago

How are you sorting to start with? What is the script for sorting to get to the phy data step? Are you doing run_sorter_by_property or something else?

alejoe91 commented 2 months ago

I think I know what's going on! How are you constructing the probe group? The two probes should be separated by a certain distance! The default sparsity looks for channels in the vicinity, so if the two probes are almost overlapping, then half of the channels will be noisy and the templates will be flat. Can you print the recording channel locations and the channel groups?
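
A minimal sketch of such a check, assuming raw_rec is the loaded recording:

import numpy as np

locations = raw_rec.get_channel_locations()
groups = raw_rec.get_channel_groups()
for g in np.unique(groups):
    # unique x coordinates seen for each channel group
    print(f"group {g}: x coords {np.unique(locations[groups == g, 0])}")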

RobertoDF commented 2 months ago

I sort with this:

path_recording, rec_file_name = find_rec_file(path_recording_folder)
timestamps = get_timestamps_from_rec(path_recording_folder)

raw_rec = read_spikegadgets(path_recording)
raw_rec.set_times(timestamps) # this is useful for other parts of the code

split_preprocessed_recording = raw_rec.split_by("group")
for probe_n, sub_rec in split_preprocessed_recording.items():
    binary_file_path = f"{path_recording_folder}/spike_interface_output/probe{probe_n}/sorter_output/recording.dat"
    probe_filename = f"{path_recording_folder}/spike_interface_output/probe{probe_n}/sorter_output/probe.prb"
    if not os.path.exists(binary_file_path):
        os.makedirs(Path(binary_file_path).parent, exist_ok=True)
        write_binary_recording(
            recording=sub_rec,
            file_paths=binary_file_path, **job_kwargs)

        pg = sub_rec.get_probegroup()
        write_prb(probe_filename, pg)
        print(f"probe {probe_n} completed \n")

    probe = load_probe(probe_filename)
    settings = {'filename': binary_file_path,
                "n_chan_bin": probe["n_chan"], "fs": sub_rec.get_sampling_frequency()}  # try nskip 20

    result_dir = path_recording_folder / "spike_interface_output" / f"probe{probe_n}" / "sorter_output"
    os.makedirs(result_dir, exist_ok=True)

    run_kilosort(settings=settings, probe=probe, data_dtype=sub_rec.get_dtype(),
                 device=torch.device("cuda"), results_dir=result_dir, clear_cache=True)

I eventually want to move everything into SpikeInterface, but I ran into some CUDA OOM errors and it was easier to investigate by using Kilosort directly.
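
For reference, a possible SpikeInterface-native version of the per-probe sorting via run_sorter_by_property (which zm711 asked about); this is only a sketch and the output folder path is illustrative:

from spikeinterface.sorters import run_sorter_by_property

# sorts each channel group separately and returns one aggregated sorting,
# replacing the manual split_by / aggregate_units steps above
aggregated_sorting = run_sorter_by_property(
    sorter_name="kilosort4",
    recording=raw_rec,
    grouping_property="group",
    folder=f"{path_recording_folder}/spike_interface_output/by_group",
)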

RobertoDF commented 2 months ago
[Screenshot: dataframe of channel locations and groups]

Actually there is some distance on the x axis (column 0 in the dataframe) and the x coords are correctly assigned between groups. Channels belonging to group 0 have the lower set of x coords (-24, -8, 8, 24). By the way, channel ids are strings in this format: '735' '734' '671' '670' '607' '606' '543' '542' '479' '478'....

RobertoDF commented 2 months ago

Regarding sparsity:

[Screenshot: the analyzer's computed sparsity]
zm711 commented 2 months ago

But what Alessio is saying is that the default sparsity radius is 50-100 um (I can't remember which off the top of my head). So since you have 4 electrode sets all within that distance, it is including all of them in the range even if a chunk of them are noisy. Alessio knows NP better than I do. How many shanks do you have in total? (He can comment better when he's back; this is just to get some more info.)

RobertoDF commented 2 months ago

I have 2 NPX1 per recording. Each NPX will have 4 different x coords. There aren't really any noisy channels. If I create the sorting_analyzer from only one sorting (not the aggregated sorting), all the waveforms look fine. The problem only happens when creating the sorting analyzer from the aggregated sorting. The confusing thing is that the channel locations look fine to me and are definitely farther apart than the default radius: https://github.com/SpikeInterface/spikeinterface/blob/d5f1481ffce3a29aabedc95920af5d091b1e1720/src/spikeinterface/core/sparsity.py#L541-L552 The behavior would be explained if waveforms were also selected from channels belonging to the other probe, which would expectedly just be noise.
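
Conceptually, the radius sparsity does something like the following sketch (an illustration, not the actual SpikeInterface code), so channel locations wrongly mapped near a unit's peak channel would indeed pull in channels from the other probe:

import numpy as np

def channels_within_radius(channel_locations, peak_channel_index, radius_um=100.0):
    # keep every channel whose contact lies within radius_um of the unit's peak channel
    distances = np.linalg.norm(channel_locations - channel_locations[peak_channel_index], axis=1)
    return np.flatnonzero(distances <= radius_um)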

RobertoDF commented 2 months ago

OK, the channels selected here are messed up. I will try to understand the precise reason later or tomorrow. https://github.com/SpikeInterface/spikeinterface/blob/d5f1481ffce3a29aabedc95920af5d091b1e1720/src/spikeinterface/core/sparsity.py#L299-L316

RobertoDF commented 2 months ago

Somehow some channels that belong to group 1 have x coords corresponding to group 0. I guess that's actually a Neo import bug.

RobertoDF commented 2 months ago

Moved to NeuralEnsemble/python-neo#1548

alejoe91 commented 2 months ago

I think the problem is actually in read_spikegadgets in probeinterface!

RobertoDF commented 2 months ago

I am reopening this because I'm not sure anymore that it is a Neo or probeinterface problem.

test rec file here: https://www.dropbox.com/s/bsjd1ipx430dr00/20240124_171519.rec?dl=0

In this case the x coords of group 0 are array([ 8. -24. 258. 226.]):

raw_rec = read_spikegadgets(Path(r"X:\data\12\ephys\20240124_171519.rec\20240124_171519.rec"), use_names_as_ids=False)

channel_locations = raw_rec.get_channel_locations()
locations_group_df = pd.concat(
    [pd.DataFrame(channel_locations, index=raw_rec.channel_ids, columns=["x", "y"]),
     pd.DataFrame(raw_rec.get_channel_groups(), index=raw_rec.channel_ids, columns=["group"])],
    axis=1)
print(locations_group_df.query("group == 0")["x"].unique())
[  8. -24. 258. 226.]

Here the x coords of group 0 are array([ 8., -24., 24., -8.]):

sub_rec = raw_rec.split_by("group")[0]

channel_locations = sub_rec.get_channel_locations()
locations_group_df = pd.concat(
    [pd.DataFrame(channel_locations, index=sub_rec.channel_ids, columns=["x", "y"]),
     pd.DataFrame(sub_rec.get_channel_groups(), index=sub_rec.channel_ids, columns=["group"])],
    axis=1)
locations_group_df.query("group == 0")["x"].unique()
array([  8., -24.,  24.,  -8.])

I can't understand what's going on here.....
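
One way to pinpoint which channels disagree, as a sketch assuming the variables above are still in scope:

full_locs = pd.DataFrame(raw_rec.get_channel_locations(), index=raw_rec.channel_ids, columns=["x", "y"])
sub_locs = pd.DataFrame(sub_rec.get_channel_locations(), index=sub_rec.channel_ids, columns=["x", "y"])
mismatched = (full_locs.loc[sub_rec.channel_ids] != sub_locs).any(axis=1)
print(sub_rec.channel_ids[mismatched.to_numpy()])  # channel ids whose location changed after split_by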

RobertoDF commented 2 months ago

The problem is solved if I set raw_rec.set_property("contact_vector", None). Doing so, the locations are accessed directly here: https://github.com/SpikeInterface/spikeinterface/blob/358e0d31df092ed4516f51f2277b2467234d6b6f/src/spikeinterface/core/baserecordingsnippets.py#L347-L367 By "solved" I mean that plot_unit_waveforms shows sensible results. [screenshot]

zm711 commented 2 months ago

@alejoe91 will be better able to comment at this depth. That seems like a bug to me (i.e., we shouldn't force someone to set contact_vector to None for this type of operation), but I'll wait for him to look this over.

RobertoDF commented 2 months ago

Sure, totally agree. I am looking further into it.

RobertoDF commented 2 months ago

Oh, it actually doesn't solve the original problem.

This test here produces sensible results (https://github.com/SpikeInterface/spikeinterface/issues/3383#issuecomment-2340197858), but plot_unit_waveforms is still off.

[screenshot]

RobertoDF commented 2 months ago

OK, I updated the function get_channel_locations in the sorting analyzer as well. [screenshot]