pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

Problems with apply_ufunc on chunked array for function with kwargs #9657

Closed jkoell closed 1 month ago

jkoell commented 1 month ago

What happened?

I've been trying to make my code more memory efficient, and as part of that I am trying to work more with chunked arrays and apply_ufunc. I'm using a function that does calculations on my dataset and returns new variables with the same dimensions. The function has several inputs with dims [time, lat, lon] that are taken as positional arguments, and some optional ones with the same dimensions that are taken as keyword-only arguments.

When I input data that are chunked in time, I get an error that seems to suggest that the positional arguments are passed as chunks, while the keyword argument passes the entire DataArray at once. For example, in the minimal example below, the error is: operands could not be broadcast together with shapes (10, 5, 8) (100, 5, 8)

I couldn't find anything in the apply_ufunc documentation or the apply_ufunc tutorial that discusses this specific problem.

What did you expect to happen?

I expected all data to be passed as chunks, and the call to return a chunked DataArray of the same size as the input.

Minimal Complete Verifiable Example

import numpy as np
import pandas as pd
import xarray as xr

data = xr.Dataset({'x1': (['time', 'lat', 'lon'], np.random.rand(100, 5, 8)), 
                    'x2': (['time', 'lat', 'lon'], np.random.rand(100, 5, 8))}, 
                    coords={'time':pd.date_range('2000-01-01', periods=100), 
                            'lat': np.arange(5), 'lon':np.arange(8)})

data_ch = data.chunk({'time':10})

def squared_sum(x1, *, x2=1):
    return (x1**2 + x2**2)

out = xr.apply_ufunc(squared_sum, data_ch['x1'],  dask='parallelized', 
               kwargs = {'x2': data_ch['x2']})

out.compute()

Relevant log output

out.compute()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/anaconda3/lib/python3.11/site-packages/xarray/core/dataarray.py", line 1202, in compute
    return new.load(**kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/lib/python3.11/site-packages/xarray/core/dataarray.py", line 1170, in load
    ds = self._to_temp_dataset().load(**kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/lib/python3.11/site-packages/xarray/core/dataset.py", line 870, in load
    evaluated_data: tuple[np.ndarray[Any, Any], ...] = chunkmanager.compute(
                                                       ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/lib/python3.11/site-packages/xarray/namedarray/daskmanager.py", line 86, in compute
    return compute(*data, **kwargs)  # type: ignore[no-untyped-call, no-any-return]
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/lib/python3.11/site-packages/dask/base.py", line 660, in compute
    results = schedule(dsk, keys, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<stdin>", line 2, in squared_sum
  File "/opt/anaconda3/lib/python3.11/site-packages/xarray/core/arithmetic.py", line 83, in __array_ufunc__
    return apply_ufunc(
           ^^^^^^^^^^^^
  File "/opt/anaconda3/lib/python3.11/site-packages/xarray/core/computation.py", line 1278, in apply_ufunc
    return apply_dataarray_vfunc(
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/lib/python3.11/site-packages/xarray/core/computation.py", line 320, in apply_dataarray_vfunc
    result_var = func(*data_vars)
                 ^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/lib/python3.11/site-packages/xarray/core/computation.py", line 831, in apply_variable_ufunc
    result_data = func(*input_data)
                  ^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/lib/python3.11/site-packages/dask/array/core.py", line 4776, in broadcast_shapes
    raise ValueError(
ValueError: operands could not be broadcast together with shapes (10, 5, 8) (100, 5, 8)

Anything else we need to know?

I also tried setting dask_gufunc_kwargs={'allow_rechunk': True}, but still received the same error.

Environment

INSTALLED VERSIONS
------------------
commit: None
python: 3.11.0 | packaged by conda-forge | (main, Jan 14 2023, 12:26:40) [Clang 14.0.6 ]
python-bits: 64
OS: Darwin
OS-release: 24.0.0
machine: arm64
processor: arm
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: ('en_US', 'UTF-8')
libhdf5: 1.12.2
libnetcdf: 4.9.1

xarray: 2024.9.0
pandas: 2.2.2
numpy: 1.26.4
scipy: None
netCDF4: 1.6.3
pydap: None
h5netcdf: None
h5py: None
zarr: None
cftime: 1.6.4
nc_time_axis: None
iris: None
bottleneck: 1.4.1
dask: 2024.9.1
distributed: 2024.9.1
matplotlib: 3.9.2
cartopy: None
seaborn: None
numbagg: None
fsspec: 2024.6.1
cupy: None
pint: None
sparse: None
flox: None
numpy_groupies: None
setuptools: 75.1.0
pip: 24.2
conda: 24.9.2
pytest: None
mypy: None
IPython: None
sphinx: None


dcherian commented 1 month ago

> I get an error that seems to suggest that the positional arguments are passed as chunks, while the keyword argument passes the entire DataArray at once.

Yes exactly. You'll need to pass it positionally.
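
For readers who want to see the mismatch without dask, here is a minimal NumPy-only sketch. It assumes a simplified picture of what happens for a single block: the positional input arrives as one time chunk of shape (10, 5, 8), while the value given through kwargs is forwarded to the function unchanged at its full shape (100, 5, 8). The names x1_block and error_message are made up for this illustration, and the slicing only imitates chunking; it is not how xarray or dask split the data internally.

```python
import numpy as np

def squared_sum(x1, *, x2=1):
    return x1**2 + x2**2

full_x1 = np.random.rand(100, 5, 8)
full_x2 = np.random.rand(100, 5, 8)

# Imitate one block of x1 along time, as a chunk size of 10 would give.
x1_block = full_x1[:10]

# The keyword argument is NOT split into blocks, so the shapes disagree.
try:
    squared_sum(x1_block, x2=full_x2)
    error_message = None
except ValueError as err:
    error_message = str(err)

# With the matching block of x2, the same call succeeds.
ok = squared_sum(x1_block, x2=full_x2[:10])
print(error_message)
print(ok.shape)
```

The first call fails with a broadcasting error of the same kind as in the traceback above; the second call shows that the function itself is fine once both inputs cover the same block.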

jkoell commented 1 month ago

Ok, thanks for the quick response! Is there any way to get around passing them positionally? As I mentioned above, these are keyword-only arguments, i.e. I can't pass them positionally. Obviously, in my simple example I could just rewrite the squared_sum function to accept them as positional arguments rather than keyword-only. But in my actual code I'm using a repo developed by someone else, so I would prefer not to have to change the underlying function.

dcherian commented 1 month ago

You can use a tiny wrapper or adapter function to translate between positional and keyword args.

jkoell commented 1 month ago

That worked, thanks! Below is the solution for my minimal example, in case anyone searching for this problem in the future comes across this thread:

def ss_wrapper(x1, x2):
    return squared_sum(x1, x2=x2)

out2 = xr.apply_ufunc(ss_wrapper, data_ch['x1'], data_ch['x2'], dask='parallelized')
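
If a function has several keyword-only array arguments, writing one wrapper per function gets repetitive. The sketch below generalizes the adapter idea from this thread. positional_adapter is a hypothetical helper written for this illustration (it is not part of xarray or dask), and it assumes the keyword arguments are supplied as the trailing positional arguments, in the order the names are listed.

```python
def positional_adapter(func, *kwarg_names):
    """Wrap func so the named keyword arguments are taken as trailing positionals.

    Hypothetical helper for illustration only; not part of xarray.
    """
    n_kw = len(kwarg_names)

    def wrapper(*args):
        if len(args) < n_kw:
            raise TypeError(
                f"expected at least {n_kw} arguments, got {len(args)}"
            )
        split = len(args) - n_kw
        positional, trailing = args[:split], args[split:]
        # Map the trailing positional values back onto their keyword names.
        return func(*positional, **dict(zip(kwarg_names, trailing)))

    return wrapper

def squared_sum(x1, *, x2=1):
    return x1**2 + x2**2

ss_adapted = positional_adapter(squared_sum, "x2")
result = ss_adapted(3, 4)  # same as squared_sum(3, x2=4)
print(result)  # 25
```

Under these assumptions, positional_adapter(squared_sum, "x2") should behave like ss_wrapper above, so it could be passed to xr.apply_ufunc in its place, with data_ch['x1'] and data_ch['x2'] both given positionally.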