[x] obtain the client with dask's own tools (`dask.distributed.get_client`) instead of having to paste these functions into a notebook where `client` is already in scope
[x] improve docs
[x] write tests
[ ] maybe test it out and make a PR to xarray?
Prototype
import xarray as xr
import dask.array
import dask.distributed as dd
def _dataarray_data(arr):
    '''Return the array backing ``arr`` (executed on a worker).'''
    return arr.data


def _dataarray_metadata(arr):
    '''Collect everything needed to rebuild ``arr`` lazily, except its values.

    Executed on a worker. The coordinates are converted to a coords-only
    :py:class:`xarray.Dataset`: the object returned by ``arr.coords`` wraps
    ``arr`` itself, so pickling it would send the full array values back to
    the client.
    '''
    return {
        'shape': arr.data.shape,
        'dtype': arr.data.dtype,
        'dims': arr.dims,
        'coords': arr.coords.to_dataset(),
        'attrs': arr.attrs,
        'name': arr.name,
    }


def dataarrays_from_delayed(futures, client=None):
    '''
    Returns a list of xarray DataArrays from a list of futures of DataArrays.

    Only metadata (shape, dtype, dims, coordinates, attributes and name) is
    gathered to the client. The array values stay on the workers and are
    wrapped lazily with :py:func:`dask.array.from_delayed`.

    Parameters
    ----------
    futures : list
        list of :py:class:`dask.distributed.Future` objects holding
        :py:class:`xarray.DataArray` objects.
    client : object, optional
        :py:class:`dask.distributed.Client` to use in gathering
        metadata on futures. If not provided, the client is obtained with
        :py:func:`dask.distributed.get_client`.

    Returns
    -------
    arrays : list
        list of :py:class:`xarray.DataArray` objects with
        :py:class:`dask.array.Array` backends, in the same order as
        ``futures``.
    '''
    if client is None:
        client = dd.get_client()

    # Futures whose results are the raw arrays; these stay on the workers.
    data_futures = client.map(_dataarray_data, futures)
    # A single round trip fetches all the (small) metadata.
    metadata = client.gather(client.map(_dataarray_metadata, futures))

    return [
        xr.DataArray(
            dask.array.from_delayed(data, meta['shape'], dtype=meta['dtype']),
            dims=meta['dims'],
            coords=meta['coords'].coords,
            attrs=meta['attrs'],
            name=meta['name'],
        )
        for data, meta in zip(data_futures, metadata)
    ]
def dataarray_from_delayed(futures, dim=None, client=None, **concat_kwargs):
    '''
    Returns a DataArray from a list of futures of DataArrays concatenated along ``dim``.

    Parameters
    ----------
    futures : list
        list of :py:class:`dask.distributed.Future` objects holding
        :py:class:`xarray.DataArray` objects.
    dim : str, optional
        dimension along which to concatenate the
        :py:class:`xarray.DataArray` objects. Forwarded unchanged to
        :py:func:`xarray.concat`; no inference is done here. The default
        ``None`` is passed through as-is, and how it is handled depends on
        the installed xarray version, so supply it explicitly.
    client : object, optional
        :py:class:`dask.distributed.Client` to use in gathering
        metadata on futures. If not provided, the client is obtained with
        :py:func:`dask.distributed.get_client`.
    **concat_kwargs
        additional keyword arguments forwarded to :py:func:`xarray.concat`
        (e.g. ``coords`` or ``compat``).

    Returns
    -------
    array : object
        :py:class:`xarray.DataArray` concatenated along ``dim`` with
        a :py:class:`dask.array.Array` backend.
    '''
    data_arrays = dataarrays_from_delayed(futures, client=client)
    return xr.concat(data_arrays, dim=dim, **concat_kwargs)
def _dataset_metadata(ds):
    '''Collect everything needed to rebuild ``ds`` lazily, except its data values.

    Executed on a worker. Per data variable it returns
    ``(dims, shape, dtype, attrs)``. The dataset's coordinates are returned
    once, as a coords-only :py:class:`xarray.Dataset`, so no data-variable
    values are pickled back to the client.
    '''
    return {
        'variables': {
            name: (var.dims, var.data.shape, var.data.dtype, var.attrs)
            for name, var in ds.data_vars.items()
        },
        'coords': ds.coords.to_dataset(),
        'attrs': ds.attrs,
    }


def _dataset_variable_data(ds, name):
    '''Return the array backing data variable ``name`` of ``ds`` (executed on a worker).'''
    return ds[name].data


def datasets_from_delayed(futures, client=None):
    '''
    Returns a list of xarray Datasets from a list of futures of Datasets.

    Only metadata is gathered to the client, in a single round trip. The
    data-variable values stay on the workers and are wrapped lazily with
    :py:func:`dask.array.from_delayed`.

    Parameters
    ----------
    futures : list
        list of :py:class:`dask.distributed.Future` objects holding
        :py:class:`xarray.Dataset` objects.
    client : object, optional
        :py:class:`dask.distributed.Client` to use in gathering
        metadata on futures. If not provided, the client is obtained with
        :py:func:`dask.distributed.get_client`.

    Returns
    -------
    arrays : list
        list of :py:class:`xarray.Dataset` objects with
        :py:class:`dask.array.Array` backends for each data variable, in the
        same order as ``futures``. Coordinates and dataset attributes are
        copied from the originals.
    '''
    if client is None:
        client = dd.get_client()

    metadata = client.gather(client.map(_dataset_metadata, futures))

    datasets = []
    for future, meta in zip(futures, metadata):
        data_vars = {}
        for name, (dims, shape, dtype, attrs) in meta['variables'].items():
            # ``name`` is passed as a task argument rather than captured by a
            # closure, so every task unambiguously targets its own variable.
            data_future = client.submit(_dataset_variable_data, future, name)
            data = dask.array.from_delayed(data_future, shape, dtype=dtype)
            data_vars[name] = (dims, data, attrs)
        datasets.append(
            xr.Dataset(
                data_vars,
                coords=meta['coords'].coords,
                attrs=meta['attrs'],
            )
        )
    return datasets
def dataset_from_delayed(futures, dim=None, client=None, **concat_kwargs):
    '''
    Returns an :py:class:`xarray.Dataset` from a list of futures of Datasets concatenated along ``dim``.

    Parameters
    ----------
    futures : list
        list of :py:class:`dask.distributed.Future` objects holding
        :py:class:`xarray.Dataset` objects.
    dim : str, optional
        dimension along which to concatenate the
        :py:class:`xarray.Dataset` objects. Forwarded unchanged to
        :py:func:`xarray.concat`; no inference is done here. The default
        ``None`` is passed through as-is, and how it is handled depends on
        the installed xarray version, so supply it explicitly.
    client : object, optional
        :py:class:`dask.distributed.Client` to use in gathering
        metadata on futures. If not provided, the client is obtained with
        :py:func:`dask.distributed.get_client`.
    **concat_kwargs
        additional keyword arguments forwarded to :py:func:`xarray.concat`
        (e.g. ``data_vars``, ``coords`` or ``compat``).

    Returns
    -------
    array : object
        :py:class:`xarray.Dataset` concatenated along ``dim`` with
        :py:class:`dask.array.Array` backends for each variable.
    '''
    datasets = datasets_from_delayed(futures, client=client)
    return xr.concat(datasets, dim=dim, **concat_kwargs)
TODO
Prototype