pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

A broadcasting sum for xarray.Dataset #6053

Open mjwillson opened 2 years ago

mjwillson commented 2 years ago

I've found it useful to have a version of Dataset.sum which sums variables in a way that's consistent with what would happen if they were broadcast to the full Dataset dimensions.

The difference is in how it handles variables that lack some of the dimensions being summed over: the standard sum skips the summation over those dimensions for those variables, whereas a broadcast_sum multiplies the variable by the product of the sizes of the missing dimensions, like so:

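import math

def broadcast_sum(dataset, dims):
  def broadcast_sum_var(var):
    # Split the requested dims by whether this variable actually has them.
    present_sum_dims = [dim for dim in dims if dim in var.dims]
    non_present_sum_dims = [dim for dim in dims if dim not in var.dims]
    # math.prod of an empty sequence is the int 1 (np.prod([]) is the float 1.0).
    scale = math.prod(dataset.sizes[dim] for dim in non_present_sum_dims)
    return var.sum(present_sum_dims) * scale
  return dataset.map(broadcast_sum_var)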

This is consistent with mathematical sum notation, where the sum doesn't become a no-op just because the summand doesn't reference the index being summed over. E.g.:

$\sum_{n=1}^N x = N x$
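As a quick numerical check of that identity, here is a minimal NumPy-only sketch. The array `x` and the length `N` are made-up values for illustration, not anything from the proposal itself.

```python
import numpy as np

N = 4
x = np.array([1.0, 2.0, 3.0])  # summand that does not depend on the index n

# Broadcast x along a new leading axis of length N, then sum over that axis.
tiled = np.broadcast_to(x, (N, x.size))
summed = tiled.sum(axis=0)

# The same result without the broadcast: multiply by N.
scaled = N * x

print(summed, scaled)
```

Both arrays hold the same values, which is the behaviour the broadcasting sum is meant to reproduce for variables that lack a summed dimension.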

I've found it useful when you need to do some broadcasting operations across different variables after the sum, and you want the summation done in a way that's consistent with the broadcasting logic that will be applied later.

Would you be open to adding this, and if so, any preference as to how? (A separate method, or an option to .sum?)

dcherian commented 2 years ago

xr.broadcast(ds)[0].sum(dims) should do this.

We could add it here: https://xarray.pydata.org/en/latest/howdoi.html and to the docs under Aggregations
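To make the difference concrete, here is a small sketch comparing the plain sum with the one-liner above. The dataset, its variable names `a` and `b`, and the dimension names are hypothetical and chosen only for illustration.

```python
import numpy as np
import xarray as xr

# Hypothetical dataset: "a" has both dims, "b" lacks the "y" dim.
ds = xr.Dataset(
    {
        "a": (("x", "y"), np.ones((2, 3))),
        "b": (("x",), np.array([10.0, 20.0])),
    }
)

# Plain sum: "b" has no "y" dim, so it passes through unchanged.
plain = ds.sum("y")

# Broadcast first: "b" is expanded to (x, y) and then summed over "y",
# so each value is counted once per "y" entry.
broadcasted = xr.broadcast(ds)[0].sum("y")

print(plain["b"].values)
print(broadcasted["b"].values)
```

In this example the two results agree for `a` and differ by a factor of `ds.sizes["y"]` for `b`.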

headtr1ck commented 2 years ago

See discussion in https://github.com/pydata/xarray/issues/6749

Maybe the current implementation of sum is not correct?

mjwillson commented 2 years ago

Re xr.broadcast(ds)[0].sum(dims) -- Thanks, that's neat and may be useful as a workaround, but it looks like it could incur significant extra CPU and RAM costs (tiling all variables to the full size in memory before summing over the tiled values)? Or is there some clever optimisation under the hood which would avoid this?

I also only wanted it to (behave as though it) broadcast the dims that are summed over, but this looks like it will broadcast all dims including those not summed over?
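The second caveat can be seen in a small sketch. It assumes a hypothetical dataset with a scalar variable `c`, and restates the `broadcast_sum` helper from the first comment (with a guard for the no-dims-present case, added here for the example) so that the block is self-contained.

```python
import math

import numpy as np
import xarray as xr


def broadcast_sum(dataset, dims):
    """Sum over `dims`, scaling variables that lack some of those dims."""

    def broadcast_sum_var(var):
        present = [d for d in dims if d in var.dims]
        missing = [d for d in dims if d not in var.dims]
        # Only reduce when the variable actually has one of the dims.
        summed = var.sum(present) if present else var
        return summed * math.prod(dataset.sizes[d] for d in missing)

    return dataset.map(broadcast_sum_var)


# Hypothetical data: "c" is a scalar with no dims at all.
ds = xr.Dataset(
    {
        "a": (("x", "y"), np.ones((2, 3))),
        "c": xr.DataArray(5.0),
    }
)

via_broadcast = xr.broadcast(ds)[0].sum(["y"])
via_helper = broadcast_sum(ds, ["y"])

# xr.broadcast also expands "c" along "x", which is not being summed over;
# the helper leaves "c" as a scalar and only scales it by the size of "y".
print(via_broadcast["c"].dims, via_helper["c"].dims)
```

Here the values of `c` agree after the sum, but the workaround leaves it with an extra `x` dimension that the helper does not introduce.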

Overall I think it'd be better to have an option on sum (like missing_dim='broadcast' as suggested in #6749), rather than documenting a partial workaround like this, given the caveats attached to the workaround and that (to me at least) the broadcasting sum is more in keeping with the usual mathematical semantics of 'sum' than what 'sum' currently does.

dcherian commented 5 months ago

A more explicit API could be ds.broadcasting.sum()
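Until something like that exists in xarray itself, a similar spelling can be approximated in user code with xarray's accessor mechanism. The sketch below is only an assumption about how such an API might look: the accessor name `broadcasting`, the class name, and the `sum` signature are hypothetical, and only `xr.register_dataset_accessor` is an existing xarray function.

```python
import math

import numpy as np
import xarray as xr


@xr.register_dataset_accessor("broadcasting")  # hypothetical accessor name
class BroadcastingAccessor:
    def __init__(self, ds):
        self._ds = ds

    def sum(self, dims):
        # Accept a single dim name or an iterable of names.
        dims = [dims] if isinstance(dims, str) else list(dims)
        sizes = self._ds.sizes

        def _sum_var(var):
            present = [d for d in dims if d in var.dims]
            missing = [d for d in dims if d not in var.dims]
            summed = var.sum(present) if present else var
            # Scale as if the variable had been broadcast over the missing dims.
            return summed * math.prod(sizes[d] for d in missing)

        return self._ds.map(_sum_var)


# Made-up example data.
ds = xr.Dataset(
    {
        "a": (("x", "y"), np.ones((2, 3))),
        "b": (("x",), np.array([10.0, 20.0])),
    }
)
out = ds.broadcasting.sum("y")
print(out["b"].values)
```

This keeps `Dataset.sum` untouched and makes the broadcasting behaviour opt-in at the call site, which seems to be the point of the more explicit spelling.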