larray-project / larray

N-dimensional labelled arrays in Python
https://larray.readthedocs.io/
GNU General Public License v3.0

add Array.extract? #1001

Open gdementen opened 2 years ago

gdementen commented 2 years ago

Our users relatively often need to extract a few labels from an axis as another (possibly existing) axis with different labels. Currently, they usually use set_axes for this, as in:

>>> arr = ndtest(5)
>>> arr
a  a0  a1  a2  a3  a4
    0   1   2   3   4
>>> b = Axis('b=b1,b2')
>>> arr['a1,a3'].set_axes('a', b)
b   b1   b2
   1.0  3.0

But the axis definition is often far from the set_axes call (sometimes even in a different file). Because set_axes replaces labels by position, nothing checks that a1 really ends up as b1, so there is a high risk of getting the label order wrong. I have actually seen this happen a few times, which is a pity given that preventing this class of errors is one of the missions of LArray.

The alternative I recommend is to use set_labels with a mapping, but then the original labels have to be specified twice (once to select them, once in the mapping). I am not sure whether that is the reason, but our users are generally not very enthusiastic about this recommendation.

>>> arr['a1,a3'].set_labels({'a1': 'b1', 'a3': 'b2'}).rename('a', 'b')
b   b1   b2
   1.0  3.0
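The underlying difference is positional versus name-based relabeling. A toy sketch in plain Python (not larray) makes it concrete, assuming a 1-D labelled array is modelled as a dict `{label: value}`. `relabel_positional` (mimicking set_axes) and `relabel_mapped` (mimicking set_labels with a map) are hypothetical helpers for illustration only.

```python
# Toy model, NOT larray: a 1-D labelled array is a dict {label: value},
# and dict insertion order plays the role of the axis order.

def relabel_positional(data, new_labels):
    """Hypothetical helper mimicking set_axes: the i-th old label gets new_labels[i]."""
    if len(new_labels) != len(data):
        raise ValueError("label count mismatch")
    return dict(zip(new_labels, data.values()))


def relabel_mapped(data, label_map):
    """Hypothetical helper mimicking set_labels with a dict: the map's order is irrelevant."""
    return {label_map[old]: value for old, value in data.items()}


subset = {"a1": 1, "a3": 3}

# The target axis was defined elsewhere, with its labels in the "wrong" order...
b_labels = ["b2", "b1"]
print(relabel_positional(subset, b_labels))  # {'b2': 1, 'b1': 3}: a1 silently became b2

# ...whereas a mapping cannot be fooled by ordering mistakes.
print(relabel_mapped(subset, {"a3": "b2", "a1": "b1"}))  # {'b1': 1, 'b2': 3}
```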

I wonder if introducing a new "extract" method would help with this:

>>> arr.extract({'a1': 'b1', 'a3': 'b2'}, 'b')
b   b1   b2
   1.0  3.0
>>> # works with a predefined axis too
>>> arr.extract({'a3': 'b2', 'a1': 'b1'}, b)
b   b1   b2
   1.0  3.0

Here is a quick and dirty implementation I did for testing:

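from larray import Axis


def extract(array, label_map, axis=None):
    # labels to keep are the keys of the {old_label: new_label} map
    orig_keys = list(label_map.keys())
    # find which axis of array contains those labels (private helper)
    subset = array.axes._guess_axis(orig_keys)
    old_axis = subset.axis
    # select the labels, then rename them using the map
    array = array[subset].set_labels(old_axis, label_map)
    if axis is not None:
        # give the extracted axis its target name
        array = array.rename(old_axis, axis)
        if isinstance(axis, Axis):
            # predefined axis: reorder the labels to match it
            array = array.reindex(axis, axis)
    return array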
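The intended semantics can also be modelled without larray, which may help when discussing edge cases. This is a minimal pure-Python sketch, assuming a 1-D array is a dict `{label: value}` and a predefined axis is a plain list of labels; `extract_1d` is a hypothetical name, not larray code. One deliberate difference: the `reindex` call in the implementation above would fill target labels absent from the map with NaN (its default fill value), whereas this toy version raises an error.

```python
def extract_1d(data, label_map, target_labels=None):
    """Hypothetical pure-Python model of the proposed extract, for a single axis.

    data: {label: value}, standing in for a 1-D labelled array.
    label_map: {old_label: new_label}; only these labels are kept.
    target_labels: optional list standing in for a predefined Axis; when given,
        the result follows its label order instead of the order of label_map.
    """
    missing = [old for old in label_map if old not in data]
    if missing:
        raise KeyError(f"labels not found: {missing}")
    if len(set(label_map.values())) != len(label_map):
        raise ValueError("new labels must be unique")
    # Select the requested labels and rename them; order follows label_map.
    result = {new: data[old] for old, new in label_map.items()}
    if target_labels is not None:
        # Stricter than reindex: the map must produce exactly the axis labels.
        if set(target_labels) != set(result):
            raise ValueError("label_map does not match the predefined axis labels")
        result = {label: result[label] for label in target_labels}
    return result


arr = {f"a{i}": i for i in range(5)}  # same values as ndtest(5)
print(extract_1d(arr, {"a1": "b1", "a3": "b2"}))  # {'b1': 1, 'b2': 3}
print(extract_1d(arr, {"a3": "b2", "a1": "b1"}, ["b1", "b2"]))  # {'b1': 1, 'b2': 3}
```

Without a predefined axis, the output order simply follows the map, so `{'a3': 'b2', 'a1': 'b1'}` alone would give `b2, b1`; the predefined axis is what pins the order down.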

Another option would be to generalize aggregate methods so that the new aggregated axis can be named explicitly (see #1002), which we probably need to implement anyway.

Currently, the above "extract" test can also be written as:

arr.sum('a1 >> b1;a3 >> b2').rename('a', 'b')

and it would be nice to be able to express it like this instead:

arr.sum(b='a1 >> b1;a3 >> b2')

or (I am unsure which one, maybe both):

arr.sum('b=a1 >> b1;a3 >> b2')

One final option would be an extract method using the same syntax as aggregate methods instead of a dict:

arr.extract('b=a1 >> b1;a3 >> b2')

None of these string-based options supports the existing-axis use case, though. I have not found a nice way to express that yet.
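The string form `'b=a1 >> b1;a3 >> b2'` carries the same information as the dict form plus a target axis name. A minimal parser sketch, assuming every group holds a single label (larray's full group syntax, with ranges or comma-separated lists, is deliberately ignored); `parse_extract_spec` is a hypothetical name, not larray's actual group parser.

```python
def parse_extract_spec(spec):
    """Hypothetical parser for specs such as 'b=a1 >> b1;a3 >> b2'.

    Returns (target_axis_name_or_None, {old_label: new_label}).
    Assumes each ';'-separated group is a single 'old >> new' pair.
    """
    head, sep, tail = spec.partition("=")
    # Only treat the text before '=' as an axis name if it is not itself a group.
    if sep and ">>" not in head and ";" not in head:
        axis_name, body = head.strip(), tail
    else:
        axis_name, body = None, spec
    label_map = {}
    for group in body.split(";"):
        old, arrow, new = group.partition(">>")
        old, new = old.strip(), new.strip()
        if not arrow or not old or not new:
            raise ValueError(f"expected 'old >> new', got {group!r}")
        if old in label_map:
            raise ValueError(f"label {old!r} extracted twice")
        label_map[old] = new
    return axis_name, label_map


print(parse_extract_spec("b=a1 >> b1;a3 >> b2"))  # ('b', {'a1': 'b1', 'a3': 'b2'})
```

Since the target is only a name here, the result can feed a dict-based extract, but it still says nothing about a predefined Axis and its label order.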

gdementen commented 1 year ago

If we implement this, it should also work for multiple axes at once:

>>> arr = ndtest((3, 4))
>>> arr.extract({'a': {'b1': 'a1', 'b2': 'a0'}, 'b': {'a1': 'b1', 'a3': 'b2'}})

I am unsure whether that would work with existing axes:

>>> b = Axis('b=b1,b2')
>>> arr.extract({'a': {'b1': 'a1', 'b2': 'a0'}, b: {'a1': 'b1', 'a3': 'b2'}})
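To check that the multi-axis form is at least well-defined, here is a toy sketch in plain Python (not larray). It assumes the outer keys name the *target* axes and that each source axis is guessed from the inner keys, similar to what `_guess_axis` does in the single-axis implementation above. A `(name, labels)` tuple stands in for a predefined `Axis` used as a key. `guess_axis` and `extract_multi` are hypothetical names, and the example uses labels that exist in a 3 x 4 array.

```python
from itertools import product


def guess_axis(axes, labels):
    """Hypothetical helper: index of the single axis containing all `labels`."""
    matches = [i for i, (_, axis_labels) in enumerate(axes) if set(labels) <= set(axis_labels)]
    if len(matches) != 1:
        raise ValueError(f"cannot find a unique axis containing {labels}")
    return matches[0]


def extract_multi(axes, data, spec):
    """Hypothetical multi-axis extract on a toy N-D array.

    axes: list of (name, labels); data: {tuple_of_labels: value}.
    spec: {target: {old_label: new_label}} where target is either a new axis
    name (str) or a (name, labels) tuple standing in for a predefined Axis.
    """
    new_axes = list(axes)
    maps = [None] * len(axes)  # per source axis: {old: new}, or None if untouched
    for target, label_map in spec.items():
        if isinstance(target, tuple):
            # Predefined axis: reorder the map so new labels follow the axis order.
            name, target_labels = target
            inverse = {new: old for old, new in label_map.items()}
            if set(inverse) != set(target_labels):
                raise ValueError(f"map does not cover exactly the labels of {name!r}")
            label_map = {inverse[new]: new for new in target_labels}
        else:
            name = target
        i = guess_axis(axes, list(label_map))
        if maps[i] is not None:
            raise ValueError(f"axis {axes[i][0]!r} extracted twice")
        maps[i] = label_map
        new_axes[i] = (name, list(label_map.values()))
    names = [name for name, _ in new_axes]
    if len(set(names)) != len(names):
        raise ValueError(f"duplicate axis names in result: {names}")
    # Per axis, the (old, new) label pairs to keep; untouched axes keep every label.
    pairs = [list(m.items()) if m is not None else [(label, label) for label in labels]
             for m, (_, labels) in zip(maps, axes)]
    new_data = {}
    for combo in product(*pairs):
        old_key = tuple(old for old, _ in combo)
        new_key = tuple(new for _, new in combo)
        new_data[new_key] = data[old_key]
    return new_axes, new_data


# A 3 x 4 array filled like ndtest((3, 4)): value = 4 * row + column.
a_labels = ["a0", "a1", "a2"]
b_labels = ["b0", "b1", "b2", "b3"]
axes = [("a", a_labels), ("b", b_labels)]
data = {(a, b): 4 * i + j for i, a in enumerate(a_labels) for j, b in enumerate(b_labels)}

new_axes, new_data = extract_multi(
    axes, data, {"a": {"b1": "a1", "b2": "a0"}, "b": {"a0": "b1", "a2": "b2"}}
)
print(new_axes)  # [('b', ['b1', 'b2']), ('a', ['a1', 'a0'])]
```

Swapping axis names works here because the source axes are guessed from the *old* labels before any renaming. For a predefined axis as a key to work in larray itself, the Axis object would additionally need to be hashable.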