sunpy / ndcube

A base package for multi-dimensional contiguous and non-contiguous coordinate-aware arrays.
http://docs.sunpy.org/projects/ndcube/
BSD 2-Clause "Simplified" License

Cannot pass `coord_params` to underlying plotting method through `plot` #627

Open wtbarnes opened 1 year ago

wtbarnes commented 1 year ago

Describe the bug

When plotting a cube with more than two dimensions, .plot() produces an animation over the additional axes, which is handled by mpl_animators. ArrayAnimatorWCS accepts a coord_params dictionary that can be used to set parameters for axis labels, tick labels, etc. However, when I pass this dict through .plot, I get the exception shown below. This seems to be because ndcube supplies its own coord_params but does not account for a user-specified one:

https://github.com/sunpy/ndcube/blob/481510bb683fc359c4e39ad8b42a89de8e9d68b7/ndcube/visualization/mpl_plotter.py#L193-L194

If a user also specifies coord_params, ArrayAnimatorWCS receives that keyword twice (once explicitly from ndcube and once through **kwargs), which raises the exception below.
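The collision is ordinary Python keyword-argument behaviour and can be reproduced without ndcube. In this sketch, animator and animate_cube are hypothetical stand-ins for ArrayAnimatorWCS and MatplotlibPlotter._animate_cube; only the calling pattern is taken from the traceback.

```python
def animator(data, wcs, plot_axes, coord_params=None, **kwargs):
    """Hypothetical stand-in for ArrayAnimatorWCS; returns what it received."""
    return coord_params, kwargs


def animate_cube(**kwargs):
    """Hypothetical stand-in for _animate_cube before any fix."""
    coord_params = {}  # plays the role of the plotter-generated coord_params
    # coord_params is passed explicitly AND may also be present in **kwargs.
    return animator(None, None, None, coord_params=coord_params, **kwargs)


# Without a user-supplied coord_params the call succeeds.
print(animate_cube())  # prints ({}, {})

# With one, the same keyword arrives twice and Python raises TypeError.
try:
    animate_cube(coord_params={"hpln": {"axislabel": "A New Longitude Label"}})
    error_message = None
except TypeError as err:
    error_message = str(err)

# The message ends with: got multiple values for keyword argument 'coord_params'
print(error_message)
```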

If this is not a bug but rather a feature request, feel free to relabel it. I did assume this would work, though, because the .plot docstring says:

Additional keyword arguments are given to the underlying plotting infrastructure which depends on the dimensionality of the data and whether 1 or 2 plot_axes are defined

https://docs.sunpy.org/projects/ndcube/en/stable/api/ndcube.visualization.mpl_plotter.MatplotlibPlotter.html#ndcube.visualization.mpl_plotter.MatplotlibPlotter.plot

To Reproduce

import ndcube
import numpy as np
import astropy.wcs

data = np.random.rand(5, 45, 45)
wcs = astropy.wcs.WCS(naxis=3)
wcs.wcs.ctype = 'HPLT-TAN', 'HPLN-TAN', "WAVE"
wcs.wcs.cunit = 'arcsec', 'arcsec', 'Angstrom'
wcs.wcs.cdelt = 10, 10, 0.2
wcs.wcs.crpix = 2, 2, 0
wcs.wcs.crval = 1, 1, 10
wcs.wcs.cname = 'HPC lat', 'HPC lon', 'wavelength'
example_cube = ndcube.NDCube(data, wcs=wcs)

coord_params = {
    'hpln': {
        'axislabel': 'A New Longitude Label'
    }
}

example_cube.plot(coord_params=coord_params)

raises the following exception:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [4], in <cell line: 21>()
     13 example_cube = ndcube.NDCube(data, wcs=wcs)
     15 coord_params = {
     16     'hpln': {
     17         'axislabel': 'A New Longitude Label'
     18     }
     19 }
---> 21 example_cube.plot(coord_params=coord_params)

File ~/mambaforge/envs/mocksipipeline/lib/python3.9/site-packages/ndcube/ndcube.py:869, in NDCube.plot(self, *args, **kwargs)
    864 if self.plotter is None:
    865     raise NotImplementedError(
    866         "This NDCube object does not have a .plotter defined so "
    867         "no default plotting functionality is available.")
--> 869 return self.plotter.plot(*args, **kwargs)

File ~/mambaforge/envs/mocksipipeline/lib/python3.9/site-packages/ndcube/visualization/mpl_plotter.py:87, in MatplotlibPlotter.plot(self, axes, plot_axes, axes_coordinates, axes_units, data_unit, wcs, **kwargs)
     84         ax = self._plot_2D_cube(plot_wcs, axes, plot_axes, axes_coordinates,
     85                                 axes_units, data_unit, **kwargs)
     86     else:
---> 87         ax = self._animate_cube(plot_wcs, plot_axes=plot_axes,
     88                                 axes_coordinates=axes_coordinates,
     89                                 axes_units=axes_units, data_unit=data_unit, **kwargs)
     91 return ax

File ~/mambaforge/envs/mocksipipeline/lib/python3.9/site-packages/ndcube/visualization/mpl_plotter.py:195, in MatplotlibPlotter._animate_cube(self, wcs, plot_axes, axes_coordinates, axes_units, data_unit, **kwargs)
    190 def _animate_cube(self, wcs, plot_axes=None, axes_coordinates=None,
    191                   axes_units=None, data_unit=None, **kwargs):
    192     # Derive inputs for animation object and instantiate.
    193     data, wcs, plot_axes, coord_params = self._prep_animate_args(wcs, plot_axes,
    194                                                                  axes_units, data_unit)
--> 195     ax = ArrayAnimatorWCS(data, wcs, plot_axes, coord_params=coord_params, **kwargs)
    197     # We need to modify the visible axes after the axes object has been created.
    198     # This call affects only the initial draw
    199     self._apply_axes_coordinates(ax.axes, axes_coordinates)

TypeError: mpl_animators.wcs.ArrayAnimatorWCS() got multiple values for keyword argument 'coord_params'

System Details

Installation method

pip

DanRyanIrish commented 1 year ago

Hi @wtbarnes. Thanks for raising this issue. I would class it as a bug. After a quick glance, I think this could be fixed simply by popping the user-supplied coord_params out of kwargs and merging them into the auto-generated ones before line 195 of mpl_plotter.py, e.g.

# Default to an empty dict so plots without a user coord_params don't raise KeyError.
user_coord_params = kwargs.pop("coord_params", {})
coord_params.update(user_coord_params)
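One caveat with dict.update is that it is shallow: if the generated and user-supplied dicts share a coordinate key, the user's inner dict replaces the generated settings for that coordinate wholesale rather than adding to them. Whether the keys actually coincide depends on how each side names the coordinates. The helper below, merge_coord_params, is a hypothetical sketch (not part of ndcube or mpl_animators) of a per-coordinate merge; the "format_unit" value is illustrative only.

```python
def merge_coord_params(generated, user=None):
    """Hypothetical helper: merge user coord_params into generated ones per coordinate.

    Unlike a plain dict.update, generated settings that the user does not
    mention (e.g. "format_unit") survive for that coordinate.
    Neither input dict is modified.
    """
    # Copy each inner dict so the generated settings are not mutated.
    merged = {name: dict(settings) for name, settings in (generated or {}).items()}
    for name, settings in (user or {}).items():
        merged.setdefault(name, {}).update(settings)
    return merged


generated = {"hpln": {"format_unit": "deg"}}  # illustrative values only
kwargs = {"coord_params": {"hpln": {"axislabel": "A New Longitude Label"}}}

coord_params = merge_coord_params(generated, kwargs.pop("coord_params", None))
print(coord_params)
# {'hpln': {'format_unit': 'deg', 'axislabel': 'A New Longitude Label'}}
```

A plain update here would have left only {'axislabel': ...} under "hpln".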

A test should be added to ensure this works. Would you be interested in opening a PR to implement this change?
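Alongside a figure test, a cheap complementary check is that the keyword plumbing works both with and without a user-supplied coord_params. This pytest-style sketch (plain asserts, so it also runs as ordinary Python) uses a hypothetical recording stub in place of ArrayAnimatorWCS and a hypothetical animate_cube that mirrors the proposed pop-then-update fix, so it needs neither ndcube nor matplotlib.

```python
def recording_animator(data, wcs, plot_axes, **kwargs):
    """Hypothetical stub for ArrayAnimatorWCS: returns the keyword arguments it got."""
    return kwargs


def animate_cube(**kwargs):
    """Hypothetical mirror of _animate_cube with the proposed fix applied."""
    coord_params = {"wave": {"format_unit": "nm"}}  # stands in for generated settings
    # Remove the user's coord_params from kwargs so the keyword is passed only once.
    coord_params.update(kwargs.pop("coord_params", {}))
    return recording_animator(None, None, None, coord_params=coord_params, **kwargs)


def test_user_coord_params_are_forwarded():
    user = {"hpln": {"axislabel": "A New Longitude Label"}}
    received = animate_cube(coord_params=user)
    assert received["coord_params"]["hpln"] == user["hpln"]
    # Generated settings for other coordinates are kept.
    assert received["coord_params"]["wave"] == {"format_unit": "nm"}


def test_plot_without_coord_params_still_works():
    received = animate_cube()
    assert received == {"coord_params": {"wave": {"format_unit": "nm"}}}
```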

wtbarnes commented 1 year ago

Sure! That seems like a reasonable solution. Re: the added test, should that be a figure test?

DanRyanIrish commented 1 year ago

Yes, I think if possible this should be a figure test. I'm not an expert on figure tests where the output is an animation. Perhaps you can compare with other such tests in ndcube, or ask @Cadair.