pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

scatter plot is slow #9129

Open mktippett opened 2 weeks ago

mktippett commented 2 weeks ago

What happened?

ds.plot.scatter is slow when the dataset has long coordinates, even though those coordinates are not involved in the scatter plot.

What did you expect to happen?

Scatter plot speed should not depend on coordinates that are not involved in the plot. This was the case at some point in the past.

Minimal Complete Verifiable Example

import numpy as np
import xarray as xr
from matplotlib import pyplot as plt
# IPython/Jupyter magics; omit these two lines when running as a plain script
%config InlineBackend.figure_format = 'retina'
%matplotlib inline

# Define coordinates
month = np.arange(1, 13, dtype=np.int64)
L = np.arange(1, 13, dtype=np.int64)

# Create random values for the variables SP and SE
np.random.seed(0)  # For reproducibility
SP_values = np.random.rand(len(L), len(month))
SE_values = SP_values + np.random.rand(len(L), len(month))

# Create the dataset
ds = xr.Dataset(
    {
        "SP": (["L", "month"], SP_values),
        "SE": (["L", "month"], SE_values)
    },
    coords={
        "L": L,
        "month": month,
        "S": np.arange(250),
        "model": np.arange(7),
        "M": np.arange(30)
    }
)
# slow
ds.plot.scatter(x='SP', y='SE')

ds = xr.Dataset(
    {
        "SP": (["L", "month"], SP_values),
        "SE": (["L", "month"], SE_values)
    },
    coords={
        "L": L,
        "month": month
    }
)
# fast
ds.plot.scatter(x='SP', y='SE')

MVCE confirmation

Relevant log output

No response

Anything else we need to know?

For me, slow = 25 seconds and fast = instantaneous

Environment

INSTALLED VERSIONS
------------------
commit: None
python: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:45:13) [Clang 16.0.6 ]
python-bits: 64
OS: Darwin
OS-release: 23.5.0
machine: x86_64
processor: i386
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: ('en_US', 'UTF-8')
libhdf5: 1.14.3
libnetcdf: 4.9.2

xarray: 2024.6.0
pandas: 2.2.2
numpy: 1.26.4
scipy: 1.13.1
netCDF4: 1.6.5
pydap: installed
h5netcdf: 1.3.0
h5py: 3.11.0
zarr: 2.18.2
cftime: 1.6.4
nc_time_axis: 1.4.1
iris: None
bottleneck: 1.3.8
dask: 2024.6.0
distributed: 2024.6.0
matplotlib: 3.8.4
cartopy: 0.23.0
seaborn: 0.13.2
numbagg: 0.8.1
fsspec: 2024.6.0
cupy: None
pint: 0.24
sparse: 0.15.4
flox: 0.9.8
numpy_groupies: 0.11.1
setuptools: 70.0.0
pip: 24.0
conda: None
pytest: 8.2.2
mypy: None
IPython: 8.17.2
sphinx: None

headtr1ck commented 2 weeks ago

Thanks for the report and the example. This is really crazy: the slow example takes 2 min on my machine, while the fast one is basically instant.

After digging a bit, the problem seems to be that xr.plot.dataset_plot._temp_dataarray broadcasts everything against everything, creating a super large array of shape (12, 12, 250, 7, 30)...
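
To put a number on that shape, here is a rough sketch of the size blow-up. It only uses the public xr.broadcast function and a cut-down copy of the dataset from the report. It does not call the private _temp_dataarray helper, so it illustrates the effect and is not a trace of what the plotting code actually does.

```python
import math

import numpy as np
import xarray as xr

# Cut-down version of the dataset from the report: one (12, 12) variable
# plus three coordinates that the variable does not use.
ds = xr.Dataset(
    {"SP": (["L", "month"], np.zeros((12, 12)))},
    coords={
        "L": np.arange(1, 13),
        "month": np.arange(1, 13),
        "S": np.arange(250),
        "model": np.arange(7),
        "M": np.arange(30),
    },
)

# Number of points that actually need plotting.
used_size = ds["SP"].size

# Number of elements if everything is broadcast over every dimension
# of the dataset: 12 * 12 * 250 * 7 * 30.
full_size = math.prod(ds.sizes.values())

# Broadcasting against just one unrelated coordinate already
# multiplies the size by 250.
sp_b, _ = xr.broadcast(ds["SP"], ds["S"])

print(used_size, sp_b.size, full_size)
```

So 144 plotted points turn into about 7.5 million elements per broadcast array, which would go some way towards explaining the timings reported above.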

@Illviljan do you have any idea? Removing unnecessary coords before the broadcast would probably help?
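
Until this is fixed in xarray itself, a possible user-side workaround is to drop the unrelated coordinates before plotting. The sketch below is only an assumption about what "removing unnecessary coords" could look like from the outside. drop_unused_coords is a hypothetical helper written for this comment, not an xarray function, and it is not the fix proposed for the plotting code.

```python
import numpy as np
import xarray as xr


def drop_unused_coords(ds, names):
    """Hypothetical helper: drop coordinates whose dimensions are not
    used by any of the variables listed in `names`."""
    used_dims = set()
    for name in names:
        used_dims.update(ds[name].dims)
    unused = [
        coord_name
        for coord_name, coord in ds.coords.items()
        if not set(coord.dims) <= used_dims
    ]
    return ds.drop_vars(unused)


# Same layout as the slow dataset in the report.
values = np.random.rand(12, 12)
ds = xr.Dataset(
    {"SP": (["L", "month"], values), "SE": (["L", "month"], values + 1.0)},
    coords={
        "L": np.arange(1, 13),
        "month": np.arange(1, 13),
        "S": np.arange(250),
        "model": np.arange(7),
        "M": np.arange(30),
    },
)

slim = drop_unused_coords(ds, ["SP", "SE"])
print(list(slim.coords))
```

Calling slim.plot.scatter(x='SP', y='SE') on the result should then behave like the fast case in the report, since only L and month are left to broadcast against.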