scverse / spatialdata-plot

Static plotting for spatialdata
BSD 3-Clause "New" or "Revised" License

Points not transformed when `method="datashader"` #337

Open clwgg opened 2 months ago

clwgg commented 2 months ago

As per the title, I just ran into a case where datashader was chosen as the method for render_points, which led to my points being plotted without the relevant transformation being applied. I stole the example from https://github.com/scverse/spatialdata-plot/issues/182 for testing below.

from spatialdata import SpatialData
from spatialdata.models import Image2DModel, PointsModel
from spatialdata.transformations import Scale
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import spatialdata_plot

sdata = SpatialData(
    images={
        "image1": Image2DModel.parse(
            np.full((10, 10, 3), fill_value=128), dims=("y", "x", "c")
        )
    },
    points={
        "points1": PointsModel.parse(
            pd.DataFrame({"y": [0.1, 0.1, 0.9, 0.9], "x": [0.1, 0.9, 0.9, 0.1]}),
            transformations={"global": Scale([10, 10], ("y", "x"))},
        )
    },
)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
sdata.pl.render_images("image1").pl.render_points("points1", method="datashader").pl.show(ax=ax1, title="datashader")
sdata.pl.render_images("image1").pl.render_points("points1", method="matplotlib").pl.show(ax=ax2, title="matplotlib")

With current main: Screenshot 2024-08-28 at 10 01 40 PM

With https://github.com/scverse/spatialdata-plot/pull/309: Screenshot 2024-08-28 at 10 02 09 PM
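
For reference, here is a minimal sketch of where the four points should end up once the `Scale([10, 10], ("y", "x"))` transformation is applied. `scale_points` is a hypothetical helper written for this illustration, not part of spatialdata; it just multiplies each axis by its factor.

```python
# Hypothetical helper (not a spatialdata function): apply a per-axis scale by hand.
def scale_points(points, factors):
    """Multiply each (y, x) pair by the matching per-axis factor."""
    return [tuple(c * f for c, f in zip(p, factors)) for p in points]


# (y, x) pairs, same values as in the DataFrame above
raw = [(0.1, 0.1), (0.1, 0.9), (0.9, 0.9), (0.9, 0.1)]

# Scale([10, 10], ("y", "x")) multiplies both axes by 10
expected = scale_points(raw, (10, 10))
print(expected)
```

The scaled points sit near the corners of the 10x10 image, at roughly (1, 1), (1, 9), (9, 9) and (9, 1). A plot that shows them bunched into the first pixel has not applied the transformation.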

LucaMarconato commented 2 months ago

Thanks for reporting. Please see the discussion on this issue also here: https://github.com/scverse/spatialdata-plot/issues/291.

Marius1311 commented 2 months ago

I think I'm seeing a consequence of that in my own data. Calling

(
    sdata_cropped
    .pl.render_points(TRANSCRIPT_KEY, size=1, color="red", method="matplotlib")
    .pl.show()
)

works just fine, but when using method="datashader", I get

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[45], line 4
      1 (
      2     sdata_cropped
      3     .pl.render_points(TRANSCRIPT_KEY, size=1, color="red", method="datashader")
----> 4     .pl.show()
      5 )

File /nas/groups/treutlein/USERS/mlange/github/spatialdata-plot/src/spatialdata_plot/pl/basic.py:895, in PlotAccessor.show(self, coordinate_systems, legend_fontsize, legend_fontweight, legend_loc, legend_fontoutline, na_in_legend, colorbar, wspace, hspace, ncols, frameon, figsize, dpi, fig, title, share_extent, pad_extent, ax, return_ax, save)
    890     wanted_elements, wanted_points_on_this_cs, wants_points = _get_wanted_render_elements(
    891         sdata, wanted_elements, params_copy, cs, "points"
    892     )
    894     if wanted_points_on_this_cs:
--> 895         _render_points(
    896             sdata=sdata,
    897             render_params=params_copy,
    898             coordinate_system=cs,
    899             ax=ax,
    900             fig_params=fig_params,
    901             scalebar_params=scalebar_params,
    902             legend_params=legend_params,
    903         )
    905 elif cmd == "render_labels" and has_labels:
    906     wanted_elements, wanted_labels_on_this_cs, wants_labels = _get_wanted_render_elements(
    907         sdata, wanted_elements, params_copy, cs, "labels"
    908     )

File /nas/groups/treutlein/USERS/mlange/github/spatialdata-plot/src/spatialdata_plot/pl/render.py:483, in _render_points(sdata, render_params, coordinate_system, ax, fig_params, scalebar_params, legend_params)
    466     color_vector = np.asarray([x[:-2] for x in color_vector])
    468 ds_result = (
    469     ds.tf.shade(
    470         ds.tf.spread(agg, px=px),
   (...)
    481     )
    482 )
--> 483 rbga_image = np.transpose(ds_result.to_numpy().base, (0, 1, 2))
    484 cax = ax.imshow(rbga_image, zorder=render_params.zorder, alpha=render_params.alpha)
    485 if aggregate_with_sum is not None:

File /links/groups/treutlein/USERS/mlange/miniforge3/envs/spatialdata/lib/python3.11/site-packages/numpy/core/fromnumeric.py:655, in transpose(a, axes)
    588 @array_function_dispatch(_transpose_dispatcher)
    589 def transpose(a, axes=None):
    590     """
    591     Returns an array with axes transposed.
    592 
   (...)
    653 
    654     """
--> 655     return _wrapfunc(a, 'transpose', axes)

File /links/groups/treutlein/USERS/mlange/miniforge3/envs/spatialdata/lib/python3.11/site-packages/numpy/core/fromnumeric.py:56, in _wrapfunc(obj, method, *args, **kwds)
     54 bound = getattr(obj, method, None)
     55 if bound is None:
---> 56     return _wrapit(obj, method, *args, **kwds)
     58 try:
     59     return bound(*args, **kwds)

File /links/groups/treutlein/USERS/mlange/miniforge3/envs/spatialdata/lib/python3.11/site-packages/numpy/core/fromnumeric.py:45, in _wrapit(obj, method, *args, **kwds)
     43 except AttributeError:
     44     wrap = None
---> 45 result = getattr(asarray(obj), method)(*args, **kwds)
     46 if wrap:
     47     if not isinstance(result, mu.ndarray):

ValueError: axes don't match array
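
One possible reading of this traceback, offered as a guess rather than a confirmed diagnosis: `_wrapfunc` only falls through to `_wrapit` when the object has no `transpose` attribute, so whatever reached `np.transpose` was not an ndarray. That would fit `ds_result.to_numpy().base` being `None`, which is what `.base` returns for an array that owns its memory. The sketch below reproduces that failure with plain NumPy and shows a hypothetical alternative that avoids `.base`. The pixel values are made up, and the little-endian packing is an assumption of this example, not something taken from datashader.

```python
import numpy as np

# An array that owns its memory has no base object.
owned = np.zeros((2, 3), dtype="<u4")
print(owned.base is None)

# np.transpose wraps None in a 0-d object array, which cannot take three axes.
message = None
try:
    np.transpose(owned.base, (0, 1, 2))
except ValueError as err:
    message = str(err)
    print(message)

# Hypothetical alternative that does not rely on .base: reinterpret each
# packed little-endian uint32 pixel as four uint8 channels.
packed = np.array([[0x04030201, 0x08070605]], dtype="<u4")
rgba = packed.view(np.uint8).reshape(*packed.shape, 4)
print(rgba.shape)
```

Whether this is what happens inside `_render_points` would need checking against the actual `ds_result`, for example by printing `type(ds_result.to_numpy().base)` just before line 483.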

LucaMarconato commented 1 month ago

@Marius1311 thanks for reporting. How did you construct sdata_cropped? It would be helpful for us if you could please reproduce your bug using the blobs dataset.

You can access it via one of these two functions:

CC @melonora

Sonja-Stockhaus commented 1 month ago

@clwgg Thanks for reporting! I reproduced the problem without the background image, in which case the points are shifted by 0.5 when using datashader (because of #216).

import pandas as pd
import spatialdata_plot  # registers the .pl accessor on SpatialData
from spatialdata import SpatialData
from spatialdata.models import PointsModel
from spatialdata.transformations import Scale

sdata = SpatialData(
    points={
        "points1": PointsModel.parse(
            pd.DataFrame({"y": [0, 0, 10, 10, 4, 6, 4, 6], "x": [0, 10, 10, 0, 4, 6, 6, 4]}),
            transformations={"global": Scale([2, 2], ("y", "x"))},
        )
    },
)
(
    sdata.pl.render_points("points1", method="matplotlib", size=50, color="lightgrey")
    .pl.render_points("points1", method="datashader", size=10, color="red")
    .pl.show()
)

With this, I get:

a) before: [image]

b) after my fix (#378): [image]

timtreis commented 3 weeks ago

@clwgg could you verify that Sonja's branch fixes the issue for you as well? :) Thanks!