jonescompneurolab / hnn-core

Simulation and optimization of neural circuits for MEG/EEG source estimates
https://jonescompneurolab.github.io/hnn-core/
BSD 3-Clause "New" or "Revised" License

`plot_cells` error when users pass in a generic matplotlib.Axes instance #735

Closed. rythorpe closed this issue 3 months ago.

rythorpe commented 3 months ago

Found by @wagdy88.

If you create a figure with multiple subplots and then pass one of those subplots into plot_cells() (as can be done for most plotting functions in viz.py), a TypeError is raised.

For example, running

from hnn_core import jones_2009_model
net = jones_2009_model()
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 1)
net.plot_cells(ax=axes)

produces the following:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 5
      3 import matplotlib.pyplot as plt
      4 fig, axes = plt.subplots(1, 1)
----> 5 net.plot_cells(ax=axes)

File ~/hnn-core/hnn_core/network.py:1418, in Network.plot_cells(self, ax, show)
   1402 def plot_cells(self, ax=None, show=True):
   1403     """Plot the cells using Network.pos_dict.
   1404 
   1405     Parameters
   (...)
   1416         The matplotlib figure handle.
   1417     """
-> 1418     return plot_cells(net=self, ax=ax, show=show)

File ~/hnn-core/hnn_core/viz.py:598, in plot_cells(net, ax, show)
    596         color = colors[cell_type]
    597         marker = markers[cell_type]
--> 598         ax.scatter(x, y, z, c=color, s=50, marker=marker, label=cell_type)
    600 if net.rec_arrays:
    601     cols = plt.get_cmap('inferno', len(net.rec_arrays) + 2)

File ~/anaconda3/envs/hnn_core/lib/python3.9/site-packages/matplotlib/__init__.py:1465, in _preprocess_data.<locals>.inner(ax, data, *args, **kwargs)
   1462 @functools.wraps(func)
   1463 def inner(ax, *args, data=None, **kwargs):
   1464     if data is None:
-> 1465         return func(ax, *map(sanitize_sequence, args), **kwargs)
   1467     bound = new_sig.bind(ax, *args, **kwargs)
   1468     auto_label = (bound.arguments.get(label_namer)
   1469                   or bound.kwargs.get(label_namer))

TypeError: scatter() got multiple values for argument 's'

I'm pretty sure the issue is that plot_cells draws a 3D scatter plot, which needs an Axes created with a 3D projection. On a regular 2D Axes, scatter() treats the third positional argument (z) as the marker size s, which then collides with the s=50 keyword and produces the cryptic TypeError above. This can be confusing to users who aren't familiar with 3D plotting in Matplotlib. I suggest that we either patch plot_cells so that it converts the Axes projection from 2D to 3D (allowing the user to pass in any generic Axes object and have it work) or remove the ax parameter from plot_cells altogether to avoid confusion.
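The root cause can be reproduced with Matplotlib alone, without hnn-core. This is a minimal sketch assuming only NumPy and Matplotlib; the random points are placeholders, not hnn-core cell positions. On a 2D Axes, the third positional argument of scatter() binds to s, so also passing s=50 raises the same TypeError as in the traceback, while an Axes created with projection="3d" accepts (x, y, z).

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
x, y, z = rng.random((3, 10))  # placeholder coordinates, not real cell positions

fig = plt.figure()

# 2D Axes: scatter(x, y, s=None, c=None, ...) binds z to `s`,
# so the explicit s=50 keyword collides with it.
ax2d = fig.add_subplot(1, 2, 1)
error_message = None
try:
    ax2d.scatter(x, y, z, c="b", s=50, marker="o")
except TypeError as err:
    error_message = str(err)  # contains: got multiple values for argument 's'
    print(error_message)

# 3D Axes: Axes3D.scatter(xs, ys, zs, ...) takes z as its third argument,
# so the same call succeeds.
ax3d = fig.add_subplot(1, 2, 2, projection="3d")
ax3d.scatter(x, y, z, c="b", s=50, marker="o")
```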

jasmainak commented 3 months ago

A fix is welcome!

gtdang commented 3 months ago

I don't think there is a way to convert an existing 2D Axes to a 3D projection. https://stackoverflow.com/questions/35209489/can-i-turn-an-existing-ax-object-into-a-3d-projection

Maybe there should be a check on the ax argument that raises an error if it's not an Axes3D object?

rythorpe commented 3 months ago

I was thinking of removing the existing 2D axes and then adding new 3D axes, but this might mess with formatting if the user starts with a more detailed grid of subplots (maybe?). Your suggestion about adding a check for an Axes3D object is probably the easiest solution, however.

gtdang commented 3 months ago

In that case I think the user would have to pass the grid/subplot/figure so you can index which one to replace. I don't think it's typical convention to pass the whole axes array to a plotting function.

rythorpe commented 3 months ago

I believe the figure is accessible via Axes.get_figure(), but I generally agree that the simpler solution, where we raise an error if the user attempts to plot on the wrong type of Axes object, is the way to go.
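If plot_cells raises an error for 2D Axes, users who want cell positions inside a subplot grid need to create that panel as 3D from the start. The sketch below uses standard Matplotlib calls. The hnn-core lines follow the issue's own example and the `plot_cells(self, ax, show)` signature shown in the traceback; they are commented out so the snippet runs without hnn-core installed.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

# Option 1: mix 2D and 3D panels by choosing the projection per subplot.
fig = plt.figure(figsize=(8, 4))
ax_2d = fig.add_subplot(1, 2, 1)                   # any ordinary 2D plot
ax_3d = fig.add_subplot(1, 2, 2, projection="3d")  # panel for plot_cells

# Option 2: make every panel in a grid 3D.
fig2, axes = plt.subplots(1, 2, subplot_kw={"projection": "3d"})

# With hnn-core installed, the 3D panel can then be passed in:
# from hnn_core import jones_2009_model
# net = jones_2009_model()
# net.plot_cells(ax=ax_3d, show=False)
```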

jasmainak commented 3 months ago

as long as we catch the error before it becomes a cryptic traceback, I think it's fine

samadpls commented 3 months ago

Hey @rythorpe, can I take this issue? Can you confirm that we just need to check whether ax is an instance of Axes3D and raise a TypeError if it isn't?

rythorpe commented 3 months ago

Yep, go for it @samadpls!