RhodiumGroup / rhg_compute_tools

Tools for using compute.rhg.com and compute.impactlab.org
MIT License
1 star · 4 forks · source link

add xarray_from_delayed #12

Closed: delgadom closed this issue 6 years ago

delgadom commented 6 years ago

TODO

Prototype

import xarray as xr
import dask.array
import dask.distributed as dd

def dataarrays_from_delayed(futures, client=None):
    '''
    Returns a list of xarray dataarrays from a list of futures of dataarrays

    Only metadata (shape, dtype, dims, name, coordinates and attributes) is
    copied to the local process. Each array's values stay with the future
    and are wrapped lazily with :py:func:`dask.array.from_delayed`, so every
    returned array is backed by a single-chunk dask array.

    Parameters
    ----------
    futures : list
        list of :py:class:`dask.distributed.Future` objects holding
        :py:class:`xarray.DataArray` objects.
    client : object, optional
        :py:class:`dask.distributed.Client` to use in gathering
        metadata on futures. If not provided, client is inferred
        from context.

    Returns
    -------
    arrays : list
        list of :py:class:`xarray.DataArray` objects with
        :py:class:`dask.array.Array` backends, in the same order as
        ``futures``.

    '''

    if client is None:
        client = dd.get_client()

    def get_values(da):
        return da.data

    def get_metadata(da):
        # Do not return ``da.coords`` directly: a DataArrayCoordinates object
        # keeps a reference to its parent DataArray, so pickling it would
        # transfer the full array to the client. ``to_dataset()`` is a
        # shallow copy holding only the coordinate variables (and indexes).
        return {
            'shape': da.data.shape,
            'dtype': da.data.dtype,
            'dims': da.dims,
            'name': da.name,
            'coords': da.coords.to_dataset(),
            'attrs': da.attrs,
        }

    # Futures for the underlying arrays; these are never gathered locally.
    value_futures = client.map(get_values, futures)

    # A single round trip fetches all metadata for all arrays.
    metadata = client.gather(client.map(get_metadata, futures))

    return [
        xr.DataArray(
            dask.array.from_delayed(
                values, meta['shape'], dtype=meta['dtype']),
            dims=meta['dims'],
            coords=meta['coords'].coords,
            attrs=meta['attrs'],
            name=meta['name'])
        for values, meta in zip(value_futures, metadata)]

def dataarray_from_delayed(futures, dim=None, client=None, **concat_kwargs):
    '''
    Returns a DataArray from a list of futures of dataarrays concatenated along ``dim``

    Parameters
    ----------
    futures : list
        list of :py:class:`dask.distributed.Future` objects holding
        :py:class:`xarray.DataArray` objects.
    dim : str, optional
        dimension along which to concat :py:class:`xarray.DataArray`.
        Passed unchanged to :py:func:`xarray.concat`; no inference is done
        here. How ``None`` is treated depends on the installed xarray
        version, so prefer passing an explicit dimension name.
    client : object, optional
        :py:class:`dask.distributed.Client` to use in gathering
        metadata on futures. If not provided, client is inferred
        from context.
    **concat_kwargs
        additional keyword arguments forwarded to :py:func:`xarray.concat`
        (e.g. ``coords`` or ``compat``).

    Returns
    -------
    array : object
        :py:class:`xarray.DataArray` concatenated along ``dim`` with
        a :py:class:`dask.array.Array` backend.
    '''

    data_arrays = dataarrays_from_delayed(futures, client=client)
    return xr.concat(data_arrays, dim=dim, **concat_kwargs)

def datasets_from_delayed(futures, client=None):
    '''
    Returns a list of xarray datasets from a list of futures of datasets

    Only metadata (per-variable shape, dtype, dims and attributes, plus the
    dataset's coordinates and attributes) is copied to the local process.
    Each data variable's values stay with the future and are wrapped lazily
    with :py:func:`dask.array.from_delayed`.

    Parameters
    ----------
    futures : list
        list of :py:class:`dask.distributed.Future` objects holding
        :py:class:`xarray.Dataset` objects.
    client : object, optional
        :py:class:`dask.distributed.Client` to use in gathering
        metadata on futures. If not provided, client is inferred
        from context.

    Returns
    -------
    datasets : list
        list of :py:class:`xarray.Dataset` objects with
        :py:class:`dask.array.Array` backends for each data variable, in
        the same order as ``futures``. All coordinates of each source
        dataset are kept, including ones not used by any data variable.
    '''

    if client is None:
        client = dd.get_client()

    def get_values(ds, name):
        # ``name`` is an explicit task argument rather than a closure
        # variable, so each task reads exactly the variable it was built for.
        return ds[name].data

    def get_metadata(ds):
        return {
            'variables': {
                name: {
                    'shape': var.data.shape,
                    'dtype': var.data.dtype,
                    'dims': var.dims,
                    'attrs': var.attrs,
                }
                for name, var in ds.data_vars.items()},
            # Shallow copy holding only the coordinate variables, so no
            # data-variable values are transferred with the metadata.
            'coords': ds.coords.to_dataset(),
            'attrs': ds.attrs,
        }

    # One blocking round trip for every dataset's metadata, instead of
    # serial per-variable ``.result()`` calls.
    metadata = client.gather(client.map(get_metadata, futures))

    datasets = []
    for future, meta in zip(futures, metadata):
        data_vars = {
            name: xr.DataArray(
                dask.array.from_delayed(
                    # submit is non-blocking; the values stay remote.
                    client.submit(get_values, future, name),
                    var_meta['shape'],
                    dtype=var_meta['dtype']),
                dims=var_meta['dims'],
                attrs=var_meta['attrs'])
            for name, var_meta in meta['variables'].items()}

        datasets.append(
            xr.Dataset(
                data_vars,
                coords=meta['coords'].coords,
                attrs=meta['attrs']))

    return datasets

def dataset_from_delayed(futures, dim=None, client=None, **concat_kwargs):
    '''
    Returns an :py:class:`xarray.Dataset` from a list of futures of datasets concatenated along ``dim``

    Parameters
    ----------
    futures : list
        list of :py:class:`dask.distributed.Future` objects holding
        :py:class:`xarray.Dataset` objects.
    dim : str, optional
        dimension along which to concat :py:class:`xarray.Dataset`.
        Passed unchanged to :py:func:`xarray.concat`; no inference is done
        here. How ``None`` is treated depends on the installed xarray
        version, so prefer passing an explicit dimension name.
    client : object, optional
        :py:class:`dask.distributed.Client` to use in gathering
        metadata on futures. If not provided, client is inferred
        from context.
    **concat_kwargs
        additional keyword arguments forwarded to :py:func:`xarray.concat`
        (e.g. ``data_vars``, ``coords`` or ``compat``).

    Returns
    -------
    dataset : object
        :py:class:`xarray.Dataset` concatenated along ``dim`` with
        :py:class:`dask.array.Array` backends for each variable.
    '''

    datasets = datasets_from_delayed(futures, client=client)
    return xr.concat(datasets, dim=dim, **concat_kwargs)