jgrivault opened this issue 4 days ago
> telling me that the graph is huge (between 800MB and 1.3GB)

Something is very wrong here, but we can't reproduce your issue without `data_structure.nc`. Can you show us the repr for `ds` (after the `open_mfdataset` call)?

FWIW, 200GB+ isn't that big a dataset, so something else is going wrong here (potentially a bad default with `open_mfdataset`, see https://github.com/pydata/xarray/issues/8778).
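If the `open_mfdataset` defaults are part of the problem, xarray's documentation suggests passing `data_vars="minimal"`, `coords="minimal"` and `compat="override"` so that only variables containing the concatenation dimension are stacked, and the others are taken from the first file instead of being compared across files. The sketch below applies those options to tiny stand-in files in a temporary directory. The helpers `write_monthly_files` and `open_one_source` are hypothetical, and whether these options shrink the graph in this particular case has not been verified.

```python
import os
import tempfile

import numpy as np
import pandas as pd
import xarray as xr


def write_monthly_files(directory, n_months=4):
    """Write one tiny file per month, mimicking the layout in the issue (demo only)."""
    for i, t in enumerate(pd.date_range("1971-01-01", periods=n_months, freq="MS")):
        ds = xr.Dataset(
            {"po4": (("t", "z", "y", "x"), np.random.rand(1, 3, 4, 5))},
            coords={"t": [t], "z": np.arange(3), "y": np.arange(4), "x": np.arange(5)},
        )
        ds.to_netcdf(os.path.join(directory, f"data_{i}.nc"))


def open_one_source(pattern):
    """Open all files of one experiment, concatenating along the time dimension."""
    return xr.open_mfdataset(
        pattern,
        combine="by_coords",  # order the files by their "t" coordinate values
        data_vars="minimal",  # only concatenate data variables that contain "t"
        coords="minimal",     # same for coordinates: z, y and x are not stacked
        compat="override",    # take non-concatenated variables from the first file
    )


with tempfile.TemporaryDirectory() as tmp:
    write_monthly_files(tmp)
    ds = open_one_source(os.path.join(tmp, "data_*.nc"))
    print(ds)  # the repr lists dimensions, dask chunks and dtypes
    ds.close()
```

Printing the dataset like this is also a quick way to produce the repr requested above.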
Sorry, here is the same code without the `data_structure.nc` file. I'll make the file available tomorrow; GitHub doesn't let me attach netCDF files. With this code the error is slightly different, but it still crashes in the msgpack code (see below).
I bet at least one issue is this recursive concat: `data = xr.concat([data, ds], dim='source')`.
I encourage you to use a single `open_mfdataset` call to open and concatenate all datasets.
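A minimal sketch of that single-call approach, assuming the directory layout from the minimal code (`<source>/data_<n>.nc`) and that all sources share the same time steps. `open_all_sources` and `file_index` are hypothetical helpers. The numeric sort matters because `combine="nested"` uses the files in exactly the order given, and a plain string sort would put `data_10.nc` before `data_2.nc`.

```python
import glob
import os
import re

import xarray as xr


def file_index(path):
    """Numeric index of data_<n>.nc, so data_10.nc sorts after data_9.nc."""
    return int(re.search(r"data_(\d+)\.nc$", path).group(1))


def open_all_sources(root, sources):
    """Open every experiment with one open_mfdataset call."""
    # Nested list: the outer level maps to "source" (a new dimension),
    # the inner level to "t" (already present in each file).
    paths = [
        sorted(glob.glob(os.path.join(root, source, "data_*.nc")), key=file_index)
        for source in sources
    ]
    ds = xr.open_mfdataset(paths, combine="nested", concat_dim=["source", "t"])
    # Concatenating along a new dimension name creates it without labels; add them.
    return ds.assign_coords(source=sources)
```

The `suffix` dimension from the original code could then be added with `ds.expand_dims(suffix=["suffix"])`.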
I added `data_structure.nc` to the original post (zipped).
I tried a single `concat` instead of my recursive one; it didn't change anything. But out of curiosity, why is the recursive approach a bad idea?
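One plausible reason, sketched on small in-memory data: every `xr.concat` on dask-backed data adds a graph layer with one entry per output chunk. In the recursive pattern each intermediate result is concatenated again, so the total task count grows roughly quadratically with the number of datasets, while a single call adds one layer. The helpers `make_source` and `n_tasks` are hypothetical. This only counts tasks; since switching to a single concat did not fix the crash here, it is unlikely to be the whole story.

```python
import numpy as np
import xarray as xr


def make_source(name, n_t=12):
    """Small dask-backed stand-in for one experiment, with a length-1 'source' dim."""
    ds = xr.Dataset({"po4": (("t", "z"), np.random.rand(n_t, 5))}).chunk({"t": 1})
    return ds.assign_coords(source=name).expand_dims("source")


def n_tasks(obj):
    """Number of tasks in the dask graph behind an xarray object."""
    return len(obj.__dask_graph__())


sources = [make_source(f"exp{i}") for i in range(6)]

# Recursive pattern from the minimal code: each step re-concatenates the previous
# result, adding a layer that references every chunk accumulated so far.
recursive = sources[0]
for ds in sources[1:]:
    recursive = xr.concat([recursive, ds], dim="source")

# Single call: one concatenation layer over all inputs.
single = xr.concat(sources, dim="source")

print("recursive:", n_tasks(recursive), "tasks")
print("single:   ", n_tasks(single), "tasks")
```

Both versions hold the same values; only the size of the graph differs.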
After a bit more testing, my workaround's performance is utterly awful. I thought that processing the dataset year by year would let me process the whole dataset despite the Dask performance warning, but not really. It worked on my 2-year test data (~200GB), but when applied to the whole dataset (120 years, several TB), it takes more than 24h to process a single year for a single DataArray, and it's still running.
For comparison, I wrote a version of the same processing without dask-backed datasets, and it took 4h30 on all 6 climate experiments (several TB).
Here is the associated workaround code. It replaces the `.compute()` in the original code.
What is your issue?
Chaining many processing steps on huge (200GB+) dask-backed datasets leads to huge graphs (500MB+) being sent to the scheduler. The more data, the bigger the graph, to the point where it becomes so large (31GB at my maximum) that `.compute()` fails with a "Failed to Serialize" error in msgpack.
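The graph size can be checked on the lazy object before calling `.compute()`. The sketch below counts tasks through `__dask_graph__()`, the dask collection method that xarray objects implement, after a few operations similar to the ones in this issue. `n_tasks` is a hypothetical helper, and the task count is only a rough proxy for the serialized size that distributed reports in its warning.

```python
import numpy as np
import xarray as xr


def n_tasks(obj):
    """Number of tasks in the dask graph behind an xarray object (0 if not lazy)."""
    graph = obj.__dask_graph__()
    return 0 if graph is None else len(graph)


# Small stand-in for one variable: 12 time steps, one chunk per time step.
po4 = xr.DataArray(np.random.rand(12, 5, 20, 20), dims=("t", "z", "y", "x")).chunk({"t": 1})

# A few chained operations; each result's graph contains all the previous steps.
surface = po4.isel(z=[0, 1, 2]).mean("z")
stats = xr.concat(
    [surface.mean("t"), surface.std("t"), surface.min("t"), surface.max("t")],
    dim="stats",
)

for name, obj in [("po4", po4), ("surface", surface), ("stats", stats)]:
    print(f"{name}: {n_tasks(obj)} tasks")
```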
This problem started when we began using xarray to process climate experiments. The amount of data we load is huge (200GB+ in my initial tests, several TB in the real case). I don't have this problem for basic processing (e.g., data selection and plotting with very little processing), but in this case we chain quite a lot of different operations (dimension expansions, dataset concatenations, data selections, means, std, min/max, another dimension expansion...). Running the exact same processing on less data (e.g., one year) only triggers a warning from Dask telling me that the graph is huge (between 800MB and 1.3GB) and that this will be a performance issue (it also suggests some good practices), but it runs. So my current workaround is to do just that: reduce the amount of data I process at once (one year at a time). I guess several intermediate `.compute()` calls would help as well, but given the amount of data we're talking about, that's not an option.
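The year-by-year workaround described above can be sketched as follows. This is not the author's (unshown) workaround code: `compute_by_year` is a hypothetical helper, and the demo applies a simple mean where the real pipeline would go.

```python
import numpy as np
import pandas as pd
import xarray as xr


def compute_by_year(ds, func):
    """Apply `func` to one calendar year at a time and compute each result.

    Each .compute() call only has to execute the tasks needed for one year,
    at the cost of one round trip to the scheduler per year.
    """
    results = {}
    for year in np.unique(ds["t"].dt.year.values):
        subset = ds.sel(t=str(year))  # partial-string selection on the time index
        results[int(year)] = func(subset).compute()
    return results


# Demo data: two years of monthly fields, lazily chunked along time.
times = pd.date_range("1971-01-01", "1972-12-01", freq="MS")
ds = xr.Dataset(
    {"po4": (("t", "y", "x"), np.random.rand(len(times), 4, 5))},
    coords={"t": times},
).chunk({"t": 1})

yearly = compute_by_year(ds, lambda d: d.mean("t"))
print(sorted(yearly))
```

The per-year loop trades one very large graph for many smaller ones; as reported in this thread, on the full 120-year dataset that trade-off turned out to be very slow.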
I don't think it's a bug... but I also don't think it's the behaviour we want from xarray. We should be able to hand any dataset to Dask and get it processed. Xarray should be able to split the graphs better so they don't hit the Dask or msgpack limits.
How to reproduce the problem
Below is the minimal code. Be aware that, to be representative, it has to generate a huge amount of (random) data (200GB+). The compute at the end will use a lot of memory: this code is meant to run on an HPC node with 768GB of RAM. I can't really make it smaller, as I think the core of the problem is that I'm processing a huge amount of data.
EDIT: Here is the file needed to run this code: data_structure.zip. A version without this file is available below, but the error is a bit different.
Minimum code
```python
import os

import numpy as np
import pandas as pd
import xarray as xr
from distributed import Client

client = Client(n_workers=64, threads_per_worker=2)

# Generate fake data
generateData = False
data_structure = xr.open_dataset('data_structure.nc')
x = data_structure.x.values
y = data_structure.y.values
z = data_structure.z.values

for ids, source in enumerate(['exp1_case1', 'exp2_case1', 'exp1_case2', 'exp2_case2', 'exp1_case3', 'exp2_case3']):
    # As the problem comes from lazy operations, a sample dataset needs to be generated...
    # This generates about 200GB of random data. The dimensions are the same as in the real case.
    if generateData:
        os.makedirs(source, exist_ok=True)
        for idt, tt in enumerate(pd.date_range('1971-01-01', '1972-12-31', freq='1SME')):
            ds = xr.Dataset(
                coords={'z': z, 'x': x, 'y': y, 't': [tt]},
                data_vars={
                    'po4': (('t', 'z', 'y', 'x'), np.random.rand(1, len(z), len(y), len(x))),
                    'no3': (('t', 'z', 'y', 'x'), np.random.rand(1, len(z), len(y), len(x))),
                    'nh4': (('t', 'z', 'y', 'x'), np.random.rand(1, len(z), len(y), len(x))),
                    'oxy': (('t', 'z', 'y', 'x'), np.random.rand(1, len(z), len(y), len(x))),
                    'si': (('t', 'z', 'y', 'x'), np.random.rand(1, len(z), len(y), len(x))),
                },
            )
            outname = f'{source}/data_{idt}.nc'
            if os.path.exists(outname):
                continue
            ds.to_netcdf(outname)

    # ... so it can be loaded with xr.open_mfdataset()
    ds = (
        xr.open_mfdataset(f'{source}/data_*.nc')
        .assign_coords(suffix='suffix')
        .assign_coords(source=source)
        .expand_dims('suffix')
        .expand_dims('source')
    )
    if ids == 0:
        data = ds
    else:
        data = xr.concat([data, ds], dim='source')

# Start processing on data
# We get the sub-surface and bottom data. The index of the bottom varies in space.
out = xr.Dataset()
lastindex = xr.DataArray(
    name='index',
    coords={'x': x, 'y': y},
    data=np.random.randint(0, 10, size=(len(x), len(y))),
)
for varname in data.data_vars:
    print(f'Extracting the data for {varname}')
    if varname in ['no3']:
        out['din'] = (data['nh4'].rename('din') + data['no3'].rename('din')).isel(z=[0, 1, 2]).mean('z')
        out['dinb'] = data['nh4'].rename('dinb').isel(z=lastindex) + data['no3'].rename('dinb').isel(z=lastindex)
    elif varname in ['po4']:
        out['dip'] = data['po4'].isel(z=[0, 1, 2]).mean('z')
        out['dipb'] = data['po4'].isel(z=lastindex)
    elif varname in ['oxy']:
        out['o2b'] = data['oxy'].isel(z=lastindex)
    elif varname in ['si']:
        out['si'] = data['si'].isel(z=[0, 1, 2]).mean('z')
        out['sib'] = data['si'].isel(z=lastindex)

# Now we start computing ensemble statistics on the data
for idx, scenario in enumerate(['case1', 'case2', 'case3']):
    ds_mean = out.mean('source', skipna=True)
    ds_std = out.std('source', skipna=True)
    ds_min = out.min('source', skipna=True)
    ds_max = out.max('source', skipna=True)
    da_tokeep = (
        xr.concat((ds_mean, ds_std, ds_min, ds_max), dim='stats')
        .assign_coords(stats=['mean', 'std', 'min', 'max'])
        .assign_coords(source=[scenario])
    )
    print('working on da to be kept for scenario', scenario)
    if idx == 0:
        ds_ens = da_tokeep
    else:
        ds_ens = xr.concat([ds_ens, da_tokeep], dim='source')

# And the final compute() that crashes
ds_ens.compute()
```
Error messages
```
/home/XXXXXX/.conda/envs/d312xup/lib/python3.12/site-packages/distributed/client.py:3371: UserWarning: Sending large graph of size 8.60 GiB. This may cause some slowdown. Consider loading the data with Dask directly or using futures or delayed objects to embed the data into the graph without repetition. See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.
  warnings.warn(
2024-11-20 15:11:21,623 - distributed.protocol.core - CRITICAL - Failed to Serialize
Traceback (most recent call last):
  File "/home/XXXXXX/.conda/envs/d312xup/lib/python3.12/site-packages/distributed/protocol/core.py", line 109, in dumps
    frames[0] = msgpack.dumps(msg, default=_encode_default, use_bin_type=True)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/XXXXXX/.conda/envs/d312xup/lib/python3.12/site-packages/msgpack/__init__.py", line 36, in packb
    return Packer(**kwargs).pack(o)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "msgpack/_packer.pyx", line 279, in msgpack._cmsgpack.Packer.pack
  File "msgpack/_packer.pyx", line 276, in msgpack._cmsgpack.Packer.pack
  File "msgpack/_packer.pyx", line 265, in msgpack._cmsgpack.Packer._pack
  File "msgpack/_packer.pyx", line 232, in msgpack._cmsgpack.Packer._pack_inner
  File "msgpack/_packer.pyx", line 265, in msgpack._cmsgpack.Packer._pack
  File "msgpack/_packer.pyx", line 213, in msgpack._cmsgpack.Packer._pack_inner
  File "msgpack/_packer.pyx", line 265, in msgpack._cmsgpack.Packer._pack
  File "msgpack/_packer.pyx", line 232, in msgpack._cmsgpack.Packer._pack_inner
  File "msgpack/_packer.pyx", line 265, in msgpack._cmsgpack.Packer._pack
  File "msgpack/_packer.pyx", line 189, in msgpack._cmsgpack.Packer._pack_inner
ValueError: bytes object is too large
2024-11-20 15:11:21,626 - distributed.comm.utils - ERROR - bytes object is too large
Traceback (most recent call last):
  File "/home/XXXXXX/.conda/envs/d312xup/lib/python3.12/site-packages/distributed/comm/utils.py", line 34, in _to_frames
    return list(protocol.dumps(msg, **kwargs))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/XXXXXX/.conda/envs/d312xup/lib/python3.12/site-packages/distributed/protocol/core.py", line 109, in dumps
    frames[0] = msgpack.dumps(msg, default=_encode_default, use_bin_type=True)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/XXXXXX/.conda/envs/d312xup/lib/python3.12/site-packages/msgpack/__init__.py", line 36, in packb
    return Packer(**kwargs).pack(o)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "msgpack/_packer.pyx", line 279, in msgpack._cmsgpack.Packer.pack
  File "msgpack/_packer.pyx", line 276, in msgpack._cmsgpack.Packer.pack
  File "msgpack/_packer.pyx", line 265, in msgpack._cmsgpack.Packer._pack
  File "msgpack/_packer.pyx", line 232, in msgpack._cmsgpack.Packer._pack_inner
  File "msgpack/_packer.pyx", line 265, in msgpack._cmsgpack.Packer._pack
  File "msgpack/_packer.pyx", line 213, in msgpack._cmsgpack.Packer._pack_inner
  File "msgpack/_packer.pyx", line 265, in msgpack._cmsgpack.Packer._pack
  File "msgpack/_packer.pyx", line 232, in msgpack._cmsgpack.Packer._pack_inner
  File "msgpack/_packer.pyx", line 265, in msgpack._cmsgpack.Packer._pack
  File "msgpack/_packer.pyx", line 189, in msgpack._cmsgpack.Packer._pack_inner
ValueError: bytes object is too large
2024-11-20 15:11:21,631 - distributed.batched - ERROR - Error in batched write
Traceback (most recent call last):
  File "/home/XXXXXX/.conda/envs/d312xup/lib/python3.12/site-packages/distributed/batched.py", line 115, in _background_send
    nbytes = yield coro
             ^^^^^^^^^^
  File "/home/XXXXXX/.conda/envs/d312xup/lib/python3.12/site-packages/tornado/gen.py", line 766, in run
    value = future.result()
            ^^^^^^^^^^^^^^^
  File "/home/XXXXXX/.conda/envs/d312xup/lib/python3.12/site-packages/distributed/comm/tcp.py", line 264, in write
    frames = await to_frames(
             ^^^^^^^^^^^^^^^^
  File "/home/XXXXXX/.conda/envs/d312xup/lib/python3.12/site-packages/distributed/comm/utils.py", line 48, in to_frames
    return await offload(_to_frames)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/XXXXXX/.conda/envs/d312xup/lib/python3.12/site-packages/distributed/utils.py", line 1507, in run_in_executor_with_context
    return await loop.run_in_executor(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/XXXXXX/.conda/envs/d312xup/lib/python3.12/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/XXXXXX/.conda/envs/d312xup/lib/python3.12/site-packages/distributed/utils.py", line 1508, in
```
Environment
```
INSTALLED VERSIONS
------------------
commit: None
python: 3.12.7 | packaged by conda-forge | (main, Oct 4 2024, 16:05:46) [GCC 13.3.0]
python-bits: 64
OS: Linux
OS-release: 5.14.0-427.37.1.el9_4.x86_64
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: ('en_US', 'UTF-8')
libhdf5: 1.14.3
libnetcdf: 4.9.2
xarray: 2024.10.0
pandas: 2.2.3
numpy: 2.1.3
scipy: 1.14.1
netCDF4: 1.7.1
pydap: None
h5netcdf: 1.4.1
h5py: 3.12.1
zarr: None
cftime: 1.6.4
nc_time_axis: None
iris: None
bottleneck: None
dask: 2024.11.2
distributed: 2024.11.2
matplotlib: 3.9.2
cartopy: 0.24.0
seaborn: 0.13.2
numbagg: None
fsspec: 2024.10.0
cupy: None
pint: None
sparse: None
flox: None
numpy_groupies: None
setuptools: 75.5.0
pip: 24.3.1
conda: None
pytest: None
mypy: None
IPython: None
sphinx: None
```