pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

[proposal] concatenate by axis, ignore dimension names #3281

Open Hoeze opened 5 years ago

Hoeze commented 5 years ago

Hi, I wrote a helper function that concatenates arrays like xr.combine_nested, with three differences: it only supports xr.DataArrays, it concatenates them by axis position (similar to np.concatenate), and it overwrites all dimension names.

I often need this to combine very different feature types.

from typing import List, Optional, Sequence, Tuple, Union
import numpy as np
import xarray as xr

def concat_by_axis(
        darrs: Union[List[xr.DataArray], Tuple[xr.DataArray, ...]],
        dims: Sequence[str],
        axis: Optional[int] = None,
        **kwargs
):
    """
    Concat arrays along some axis similar to `np.concatenate`. Automatically renames the dimensions to `dims`.
    Please note that this renaming happens by the axis position, therefore make sure to transpose all arrays
    to the correct dimension order.

    :param darrs: (Possibly nested) list or tuple of xr.DataArrays
    :param dims: The dimension names of the resulting array. Renames axes where necessary.
    :param axis: The axis which should be concatenated along. If None, it is inferred from the nesting depth of `darrs`.
    :param kwargs: Additional arguments which will be passed to `xr.concat()`
    :return: Concatenated xr.DataArray with dimensions `dims`.
    """

    # Get depth of nested lists. Assumes `darrs` is correctly formatted as list of lists.
    if axis is None:
        axis = 0
        l = darrs
        # while l is a list or tuple and contains elements:
        while isinstance(l, (list, tuple)) and len(l) > 0:
            # increase depth by one
            axis -= 1
            l = l[0]
        if axis == 0:
            raise ValueError("`darrs` has to be a (possibly nested) list or tuple of xr.DataArrays!")

    to_concat = list()
    for i, da in enumerate(darrs):
        # recursive call for nested arrays;
        # most inner call should have axis = -1,
        # most outer call should have axis = - depth_of_darrs
        if isinstance(da, (list, tuple)):
            da = concat_by_axis(da, dims=dims, axis=axis + 1, **kwargs)

        if not isinstance(da, xr.DataArray):
            raise ValueError("Input %d must be an xr.DataArray" % i)
        if len(da.dims) != len(dims):
            raise ValueError("Input %d must have the same number of dimensions as specified in the `dims` argument!" % i)

        # force-rename dimensions by axis position
        da = da.rename(dict(zip(da.dims, dims)))

        to_concat.append(da)

    return xr.concat(to_concat, dim=dims[axis], **kwargs)

Would it make sense to include this in xarray?
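The helper's core idea (rename dimensions by position, and let the nesting depth pick the concatenation axis) can be reproduced for a single 2x2 block layout with plain xarray calls. A minimal sketch, assuming made-up input arrays; `rename_by_position` is a hypothetical name for illustration, not an xarray API:

```python
import numpy as np
import xarray as xr


def rename_by_position(da: xr.DataArray, dims) -> xr.DataArray:
    """Hypothetical helper: force-rename the dims of `da` to `dims`, matched by axis position."""
    if da.ndim != len(dims):
        raise ValueError(f"expected {len(dims)} dimensions, got {da.ndim}")
    return da.rename(dict(zip(da.dims, dims)))


dims = ("obs", "features")
# Four made-up blocks whose dimension names do not match each other.
a = xr.DataArray(np.zeros((2, 3)), dims=("cells", "genes"))
b = xr.DataArray(np.ones((2, 1)), dims=("cells", "peaks"))
c = xr.DataArray(np.full((4, 3), 2.0), dims=("samples", "genes"))
d = xr.DataArray(np.full((4, 1), 3.0), dims=("samples", "peaks"))

# Inner lists are joined along the last axis (dims[-1]) and the outer list along
# dims[-2], the same nesting convention that np.block uses.
rows = [
    xr.concat([rename_by_position(x, dims) for x in row], dim=dims[-1])
    for row in [[a, b], [c, d]]
]
combined = xr.concat(rows, dim=dims[-2])

print(combined.dims, combined.shape)  # ('obs', 'features') (6, 4)
```

Because these arrays carry no coordinates, xr.concat has nothing to align and the result matches `np.block([[a.values, b.values], [c.values, d.values]])` element for element. This mirrors why the helper infers `axis = -depth` for the outermost list.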

shoyer commented 5 years ago

Thanks for the suggestion!

Generally we try to keep all of xarray's functions "metadata aware", rather than overriding metadata. So the API feels a little out of place to me.

Maybe you could share a little bit more about your use case for this?

Hoeze commented 5 years ago

Thanks for your answer @shoyer. OK, then this can be closed, since this function should actually remove metadata for me :)

For example, let's consider a dataset with:

Now I want to stick those side-by-side to get an array x_combined ("obs", "features") with features = features_1 + ... + features_n.
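With the existing API, the side-by-side case can be written as a per-array rename followed by xr.concat; this keeps the shared "obs" coordinate and concatenates the feature labels. A sketch assuming two blocks with made-up sizes and labels (`x_1`, `x_2` and their coordinate values are illustrative only):

```python
import numpy as np
import xarray as xr

obs = ["o1", "o2", "o3"]
# Made-up feature blocks that share "obs" but each have their own feature dimension.
x_1 = xr.DataArray(
    np.arange(6.0).reshape(3, 2),
    dims=("obs", "features_1"),
    coords={"obs": obs, "features_1": ["f1_a", "f1_b"]},
)
x_2 = xr.DataArray(
    np.arange(9.0).reshape(3, 3),
    dims=("obs", "features_2"),
    coords={"obs": obs, "features_2": ["f2_a", "f2_b", "f2_c"]},
)

# Rename each block's second dimension (together with its index coordinate)
# to the common name "features", then concatenate along that dimension.
x_combined = xr.concat(
    [x.rename({x.dims[1]: "features"}) for x in (x_1, x_2)],
    dim="features",
)

print(x_combined.dims, x_combined.shape)  # ('obs', 'features') (3, 5)
```

Here the rename is explicit rather than implicit, so the feature labels survive and "obs" is still used for alignment, which is the metadata-aware behaviour described earlier in the thread.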

TomNicholas commented 5 years ago

@Hoeze is it possible to use xarray.stack for your example? e.g.

features_i = ("features_1", "features_2", ..., "features_n")
x_combined = original.stack(features=features_i)

Hoeze commented 5 years ago

@TomNicholas No, not really. Stack can only be used to combine multiple dimensions of a single array into one dimension, e.g. data[x, y, z] -> stacked_data[a, z] with a as a multi-index of x and y.
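A small illustration of that distinction: stack reshapes dimensions within one array rather than joining separate arrays. The sizes below are made up; note that xarray appends the new stacked dimension at the end, so the result is ordered (z, a) rather than (a, z):

```python
import numpy as np
import xarray as xr

# data[x, y, z] with made-up sizes 2 x 3 x 4
data = xr.DataArray(np.arange(24).reshape(2, 3, 4), dims=("x", "y", "z"))

# Combine the existing dimensions x and y of this one array into a new dimension "a";
# by default xarray stores the (x, y) pairs of "a" in a MultiIndex.
stacked_data = data.stack(a=("x", "y"))

print(stacked_data.dims, stacked_data.shape)  # ('z', 'a') (4, 6)
```

The total number of elements is unchanged (4 * 6 == 2 * 3 * 4), which is why stack cannot glue together independent arrays of different shapes.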

In this case, we do not have shared data with coordinates to combine.
Instead, multiple independent DataArrays should be concatenated along some dimension.

The most similar methods to this one are xr.concat and xr.combine_nested. However, they do not allow implicitly renaming dimensions or force-deleting non-shared metadata.
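The metadata-discarding behaviour can be approximated with existing calls by dropping all coordinates before the positional rename, so that xr.concat has no labels to align. A minimal sketch under that assumption; `strip_and_rename` is a hypothetical helper and the arrays are invented:

```python
import numpy as np
import xarray as xr


def strip_and_rename(da: xr.DataArray, dims) -> xr.DataArray:
    """Hypothetical helper: drop every coordinate, then force-rename dims by axis position."""
    bare = da.drop_vars(list(da.coords))
    return bare.rename(dict(zip(bare.dims, dims)))


dims = ("obs", "features")
# Invented inputs whose dimension names and coordinate labels do not match.
x_1 = xr.DataArray(
    np.ones((2, 2)), dims=("cells", "genes"),
    coords={"cells": ["o1", "o2"], "genes": ["g1", "g2"]},
)
x_2 = xr.DataArray(
    np.zeros((2, 3)), dims=("samples", "peaks"),
    coords={"samples": ["p1", "p2"], "peaks": ["k1", "k2", "k3"]},
)

# With no coordinates left, the concatenation is purely positional, like np.concatenate.
combined = xr.concat([strip_and_rename(x, dims) for x in (x_1, x_2)], dim="features")

print(combined.dims, combined.shape)  # ('obs', 'features') (2, 5)
```

Dropping coordinates also throws away the row labels, so this is only safe when the rows are already in the same order across all inputs.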