SpikeInterface / spikeinterface

A Python-based module for creating flexible and robust spike sorting pipelines.
https://spikeinterface.readthedocs.io
MIT License

Error with plotting functions when using a 3D probe #1821

Open Kayv-cmb opened 1 year ago

Kayv-cmb commented 1 year ago

Hi,

I have a small issue with the spikeinterface plotting functions when I use my probe, which has 3 dimensions.

AttributeError                            Traceback (most recent call last)
Cell In[1], line 31
     29 compute_spike_amplitudes(waveform)
     30 compute_correlograms(waveform)
---> 31 sw.plot_unit_summary(waveform, unit_id= 106)
     33 firing_rates = compute_firing_rates(waveform)
     34 isi_violation_ratio, isi_violations_count = compute_isi_violations(waveform)

File ~/.conda/envs/CHIME_kayvan/lib/python3.9/site-packages/spikeinterface/widgets/base.py:120, in define_widget_function_from_class.<locals>.widget_func(*args, **kwargs)
    117 @copy_signature(widget_class)
    118 def widget_func(*args, **kwargs):
    119     W = widget_class(*args, **kwargs)
--> 120     W.do_plot(W.backend, **W.backend_kwargs)
    121     return W.plotter

File ~/.conda/envs/CHIME_kayvan/lib/python3.9/site-packages/spikeinterface/widgets/base.py:53, in BaseWidget.do_plot(self, backend, **backend_kwargs)
     51 plotter = self.possible_backends[backend]()
     52 self.check_backend_kwargs(plotter, backend, **backend_kwargs)
---> 53 plotter.do_plot(self.plot_data, **backend_kwargs)
     54 self.plotter = plotter

File ~/.conda/envs/CHIME_kayvan/lib/python3.9/site-packages/spikeinterface/widgets/matplotlib/unit_summary.py:42, in UnitSummaryPlotter.do_plot(self, data_plot, **backend_kwargs)
     40 if dp.plot_data_unit_locations is not None:
     41     ax1 = fig.add_subplot(gs[:2, 0])
---> 42     UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1)
     43     x, y = dp.unit_location[0], dp.unit_location[1]
     44     ax1.set_xlim(x - 80, x + 80)

File ~/.conda/envs/CHIME_kayvan/lib/python3.9/site-packages/spikeinterface/widgets/matplotlib/unit_locations.py:36, in UnitLocationsPlotter.do_plot(self, data_plot, **backend_kwargs)
     33 if dp.with_channel_ids:
     34     text_on_contact = dp.channel_ids
---> 36 poly_contact, poly_contour = plot_probe(
     37     probe,
     38     ax=self.ax,
     39     contacts_colors="w",
     40     contacts_kargs=contacts_kargs,
     41     probe_shape_kwargs=probe_shape_kwargs,
     42     text_on_contact=text_on_contact,
     43 )
     44 poly_contact.set_zorder(2)
     45 if poly_contour is not None:

File ~/.conda/envs/CHIME_kayvan/lib/python3.9/site-packages/probeinterface/plotting.py:117, in plot_probe(probe, ax, contacts_colors, with_channel_index, with_contact_id, with_device_index, text_on_contact, first_index, contacts_values, cmap, title, contacts_kargs, probe_shape_kwargs, xlims, ylims, zlims, show_channel_on_click)
    114 elif probe.ndim == 3:
    115     poly = Poly3DCollection(
    116         vertices, color=contacts_colors, **_contacts_kargs)
--> 117     ax.add_collection3d(poly)
    119 if contacts_values is not None:
    120     poly.set_array(contacts_values)

AttributeError: 'AxesSubplot' object has no attribute 'add_collection3d'
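The last frame is where the two layers disagree: `plot_probe` sees `probe.ndim == 3` and builds a `Poly3DCollection`, but the `ax` it receives is an ordinary 2D subplot, and `add_collection3d` only exists on matplotlib's 3D axes (created with `projection="3d"`). A minimal matplotlib-only sketch of that distinction, with no spikeinterface involved:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers "3d" on older matplotlib)

fig = plt.figure()
ax_2d = fig.add_subplot(1, 2, 1)                   # plain 2D axes, like the summary widget panels
ax_3d = fig.add_subplot(1, 2, 2, projection="3d")  # 3D axes from mpl_toolkits.mplot3d

# Only the 3D axes provides add_collection3d, which plot_probe calls for 3D probes
has_3d_method_2d = hasattr(ax_2d, "add_collection3d")
has_3d_method_3d = hasattr(ax_3d, "add_collection3d")
print(has_3d_method_2d, has_3d_method_3d)  # False True
plt.close(fig)
```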

I assumed that not every plotting function would work with a 3D probe, especially those that need the channel positions. But since plot_probe checks probe.ndim, shouldn't this one work? In any case, is there a way to fix this error? Thank you in advance for your help!
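Until the widgets handle 3D probes, one possible workaround is to plot with a 2D version of the probe, keeping only two of the three coordinates. A minimal numpy sketch of that projection step, assuming the contact positions are available as an `(n_contacts, 3)` array; `project_contacts_to_2d` is a hypothetical helper, the positions are made up, and rebuilding a probeinterface `Probe` from the result and attaching it to the recording is not shown:

```python
import numpy as np

AXIS_INDEX = {"x": 0, "y": 1, "z": 2}


def project_contacts_to_2d(positions_3d, axes="xy"):
    """Hypothetical helper: keep two of the three coordinates of each contact."""
    positions_3d = np.asarray(positions_3d, dtype=float)
    if positions_3d.ndim != 2 or positions_3d.shape[1] != 3:
        raise ValueError("expected an array of shape (n_contacts, 3)")
    if len(axes) != 2 or any(a not in AXIS_INDEX for a in axes):
        raise ValueError("axes must be two of 'x', 'y', 'z', e.g. 'xy' or 'xz'")
    columns = [AXIS_INDEX[a] for a in axes]
    return positions_3d[:, columns]


# Made-up contact positions (micrometres), for illustration only
positions = np.array([[0.0, 0.0, 0.0],
                      [0.0, 20.0, 5.0],
                      [16.0, 40.0, 10.0]])
positions_2d = project_contacts_to_2d(positions, axes="xy")
print(positions_2d.shape)  # (3, 2)
```

The trade-off is that contacts differing only in the dropped coordinate will overlap in the 2D plots, so the choice of `axes` should follow the geometry of the probe.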

samuelgarcia commented 1 year ago

I think the problem is that the unit summary widget is a multi-panel (i.e. multi-axes) figure. All of its panels are 2D, except the probe map, which you want in 3D, and I am not sure matplotlib allows mixing the two in the same figure. Maybe it does.
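For what it is worth, matplotlib does allow a single 3D panel inside an otherwise 2D grid, because the projection is chosen per subplot when it is created. A sketch assuming a GridSpec layout similar to the one visible in the traceback (`fig.add_subplot(gs[:2, 0])` in `unit_summary.py`); the 3x3 grid, the placeholder 2D panels, and the square polygon standing in for one probe contact are made up for illustration, and this is not how the widget is currently written:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers "3d" on older matplotlib)
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

fig = plt.figure(figsize=(10, 6))
gs = GridSpec(3, 3, figure=fig)  # hypothetical 3x3 layout

# Probe panel: request a 3D projection for this grid slot only
ax_probe = fig.add_subplot(gs[:2, 0], projection="3d")
# One square "contact" given as a list of 3D vertices (made-up coordinates)
contact = [[(0, 0, 0), (10, 0, 0), (10, 10, 5), (0, 10, 5)]]
poly = Poly3DCollection(contact, facecolor="w", edgecolor="k")
ax_probe.add_collection3d(poly)  # works because ax_probe is a 3D axes
ax_probe.set_xlim(-5, 15)
ax_probe.set_ylim(-5, 15)
ax_probe.set_zlim(-5, 10)

# The remaining panels stay ordinary 2D axes (placeholder content)
ax_waveforms = fig.add_subplot(gs[:2, 1:])
ax_waveforms.plot(range(10))
ax_hist = fig.add_subplot(gs[2, :])
ax_hist.hist([1, 2, 2, 3, 3, 3])

fig.canvas.draw()  # render once to check the mixed layout draws without errors
```

A real fix in the widget would presumably also need to revisit the zoom around the unit location (the `ax1.set_xlim(x - 80, x + 80)` line in the traceback), which assumes a 2D view.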

Kayv-cmb commented 1 year ago

Thanks for the answer. I am not sure how your summary widget works, but I am pretty sure matplotlib lets you have both 3D and 2D subplots in the same figure, for example:

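import numpy as np
import matplotlib.pyplot as plt

# Placeholder data (hypothetical): X is an (n, 3) point cloud, X_transform a 2D embedding of it
rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
X_transform = X[:, :2]

size = 64
fig = plt.figure(2, figsize=(10, 4))
ax = fig.add_subplot(121, projection='3d')  # 3D axes on the left
ax.scatter(X[:, 0], X[:, 1], X[:, 2], s=size, c='r')
ax.set_title('Original Points')

ax = fig.add_subplot(122)  # ordinary 2D axes in the same figure
ax.scatter(X_transform[:, 0], X_transform[:, 1], s=size, c='r')
ax.set_title('Embedding in 2D')
fig.subplots_adjust(wspace=.4, hspace=0.5)
plt.show()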