larray-project / larray

N-dimensional labelled arrays in Python
https://larray.readthedocs.io/
GNU General Public License v3.0

add support for axes tuples in aggregates without axis= #1000

Open gdementen opened 2 years ago

gdementen commented 2 years ago

We support aggregating over multiple axes by passing multiple arguments to aggregate functions:

>>> arr = ndtest((2, 3, 4))
>>> arr.sum('a', 'c')
b  b0  b1   b2
   60  92  124
>>> arr.sum(arr.a, arr.c)
b  b0  b1   b2
   60  92  124

But sometimes the axes to aggregate are in a tuple or list, e.g. when writing a higher-level function with an "axis" argument that is passed on to lower-level aggregate functions. I know this is not really a problem when users call the aggregate functions directly, but when writing such a function, you currently have to write more (ugly) code than necessary or lose some generality/functionality.

>>> arr.sum(('a', 'c'))
ValueError: a is not a valid label for any axis
>>> arr.sum([arr.a, arr.c])
ValueError: [Axis(['a0', 'a1'], 'a'), Axis(['c0', 'c1', 'c2', 'c3'], 'c')] is not a valid label for any axis
>>> arr.sum((arr.a, arr.c))
TypeError: a has an invalid type (Axis) for a key

Currently, the workarounds are either to use an explicit axis= argument, which works for both a single axis and a tuple of axes but does not work for groups:

>>> arr.sum(axis=(arr.a, arr.c))
b  b0  b1   b2
   60  92  124
>>> arr.sum(axis=('a', 'c'))
b  b0  b1   b2
   60  92  124

Or to use *axes, which works for both a tuple of axes and a tuple of groups, but not when axes is a single axis or group:

>>> axes = ('a', 'c')
>>> arr.sum(*axes)
b  b0  b1   b2
   60  92  124
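Why *axes breaks for a single axis can be seen with plain Python, independently of larray: argument unpacking iterates over whatever it is given, so a tuple is spread into its elements, but so is any other iterable. In the sketch below, `record_args` is a hypothetical stand-in for an aggregate method that simply returns what it received. A string axis name is used because it is the simplest iterable; presumably the same thing happens with a single Axis or Group object, which is also iterable over its elements.

```python
def record_args(*axes):
    """Hypothetical stand-in for an aggregate method: returns what it received."""
    return axes


# A tuple of axis names is spread into one argument per axis, as intended.
axes = ('a', 'c')
print(record_args(*axes))       # ('a', 'c')

# A single multi-character axis name is a string, which is iterable too,
# so unpacking spreads it into characters instead of passing the name.
axis = 'age'
print(record_args(*axis))       # ('a', 'g', 'e')

# Wrapping the single axis in a 1-tuple restores the intended call,
# but every caller has to remember to do it.
print(record_args(*(axis,)))    # ('age',)
```

A one-letter name like 'a' only works by accident, because it unpacks to itself.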

I would like a solution which works in all cases: single axis, single group, tuple of axes, tuple of groups. Maybe lists of axes/groups too, but I am unsure it is worth it.
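One possible shape for such a solution is to normalize the positional arguments before the aggregate does its usual work. This is a sketch under assumptions, not larray's code: `Axis` and `Group` below are minimal hypothetical stand-ins for the real classes, and `normalize_axes` is a hypothetical helper. It spreads a lone tuple or list only when every element is an Axis or Group object and no two of them target the same axis. Tuples of strings are left alone because they are ambiguous with labels, and several groups on the same axis stay a single argument, since larray already uses that form to request one aggregate per group.

```python
from dataclasses import dataclass


# Hypothetical, minimal stand-ins for larray's Axis and Group (not the real classes).
@dataclass(frozen=True)
class Axis:
    name: str


@dataclass(frozen=True)
class Group:
    key: object   # the labels selected by the group
    axis: str     # name of the axis the group belongs to


def _target_axis(item):
    """Name of the axis an Axis or Group object would aggregate over."""
    return item.name if isinstance(item, Axis) else item.axis


def normalize_axes(args):
    """Hypothetical helper: spread a lone tuple/list of axes or groups.

    The sequence is spread into separate arguments only if every element is an
    Axis or Group object and no two of them target the same axis. Anything else
    (strings, which could be labels, or several groups on one axis) is returned
    unchanged.
    """
    if len(args) == 1 and isinstance(args[0], (tuple, list)):
        items = tuple(args[0])
        if items and all(isinstance(item, (Axis, Group)) for item in items):
            targets = [_target_axis(item) for item in items]
            if len(set(targets)) == len(targets):
                return items
    return tuple(args)


a, c = Axis('a'), Axis('c')
print(normalize_axes(((a, c),)))        # (Axis(name='a'), Axis(name='c'))
print(normalize_axes((('a', 'c'),)))    # (('a', 'c'),)
```

An aggregate method could call such a helper on its `*args` first and then proceed as if the axes had been passed separately.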

Supporting arr.sum(('a', 'c')) might not be possible (it is ambiguous with labels; IIRC, this is the reason the axis argument exists), but both arr.sum((arr.a, arr.c)) and arr.sum((arr.a[:], arr.c[:])) should work IMO. FWIW and IIRC, in some other method (I don't remember which), I used a different trade-off: when there is an ambiguity between an axis name and an axis label, the axis "wins", and you have to use an explicit group to target the label. We should find out which method that is and whether it makes sense to use the same trade-off in both cases.
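The "axis wins" trade-off amounts to a lookup order: first check whether a string names an axis, and only otherwise search the labels. The sketch below only illustrates that rule; `resolve_key` is a hypothetical function, not part of larray, and axes are modelled as a plain dict mapping axis names to their labels. The error message mimics the ValueError shown earlier in this issue.

```python
def resolve_key(key, axes):
    """Hypothetical resolver illustrating the "axis wins" trade-off.

    `axes` maps axis names to their labels. A string that names an axis is
    treated as that axis, even if it is also a label somewhere; otherwise it is
    looked up as a label. Targeting a label that collides with an axis name
    would then require an explicit group (not modelled here).
    """
    if key in axes:
        return ('axis', key)
    owners = [name for name, labels in axes.items() if key in labels]
    if len(owners) == 1:
        return ('label', owners[0], key)
    if not owners:
        raise ValueError(f"{key} is not a valid label for any axis")
    raise ValueError(f"{key} is ambiguous: it is a label of axes {owners}")


axes = {'a': ['a0', 'a1'], 'c': ['a', 'c0']}   # 'a' is both an axis name and a label of c
print(resolve_key('a', axes))    # ('axis', 'a')
print(resolve_key('c0', axes))   # ('label', 'c', 'c0')
```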

Related issue: #660

gdementen commented 2 years ago

I just stumbled on this problem again, with a slight twist: the higher-level function (weighted_mean) was given a tuple containing an axis and a tuple of groups, (X.gmp, (age[60:64], age[65:69], age[70:74], age[75:103])), and the workaround I had used so far broke.

gdementen commented 2 years ago

FWIW, my new workaround weighted_mean function is:

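import larray

def weighted_mean(a, w, axis):
    # Workaround for #1000: a tuple containing anything other than groups (axes,
    # axis names or a nested tuple of groups) must be spread with *axis.
    if isinstance(axis, tuple) and not all(isinstance(g, larray.Group) for g in axis):
        return (a * w).sum(*axis).divnot0(w.sum(*axis))
    # A single axis/group, or a tuple made only of groups, is passed as one argument.
    return (a * w).sum(axis).divnot0(w.sum(axis))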
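The dispatch in this workaround can be factored into a small helper that returns the positional arguments to unpack, so that every aggregate call in a higher-level function shares the same logic. A sketch, assuming a hypothetical `as_aggregate_args` helper and a minimal hypothetical `Group` class standing in for `larray.Group` (with larray, the isinstance check would use the real class instead):

```python
class Group:
    """Hypothetical stand-in for larray.Group: a labelled selection on one axis."""

    def __init__(self, key, axis):
        self.key = key
        self.axis = axis

    def __repr__(self):
        return f"Group({self.key!r}, axis={self.axis!r})"


def as_aggregate_args(axis):
    """Positional arguments to pass to an aggregate, as in arr.sum(*args).

    Mirrors the weighted_mean workaround: a tuple containing anything other
    than groups (axes, names, nested tuples of groups) is spread into separate
    arguments; a single axis/group, or a tuple made only of groups, is passed
    as one argument.
    """
    if isinstance(axis, tuple) and not all(isinstance(g, Group) for g in axis):
        return axis
    return (axis,)


age_groups = (Group('60:64', 'age'), Group('65:69', 'age'))
# The twist from this thread ('gmp' stands in for X.gmp): spread into two arguments.
print(len(as_aggregate_args(('gmp', age_groups))))   # 2
# A tuple made only of groups stays a single argument.
print(len(as_aggregate_args(age_groups)))            # 1
```

With such a helper, the body of weighted_mean would reduce to `args = as_aggregate_args(axis)` followed by `(a * w).sum(*args).divnot0(w.sum(*args))`.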