xarray-contrib / flox

Fast & furious GroupBy operations for dask.array
https://flox.readthedocs.io
Apache License 2.0

Roundoff error between flox and brute force aggregations #398

Closed riley-brady closed 1 week ago

riley-brady commented 1 week ago

Hi guys, love this package. I'm migrating some old aggregation code over to flox and noticed that there is some roundoff inconsistency between brute-force aggregations and using flox.

Problem: I am using flox with a raster of integers that serves as a mask for a variety of groups I want to aggregate over. When I brute-force it (just loop through the unique mask integers, create a .where() mask, and aggregate over the pixels), I get slightly different numbers than when I do this with flox.

Package Versions:

Code:

import numpy as np
import xarray as xr
import flox.xarray

def generate_data():
    """Fixed test data generation."""
    np.random.seed(333)

    da = xr.DataArray(np.random.random((100, 100)), dims=["x", "y"]).rename("data")
    mask1 = np.random.choice([1, 2, 3, 4, 5], size=(100, 100))
    mask2 = np.random.choice([6, 7, 8, 9, 10], size=(100, 100))

    da = da.assign_coords(
        {"label1": (("x", "y"), mask1), "label2": (("x", "y"), mask2)}
    )
    return da, mask1, mask2

def brute_force_method(da, agg="sum"):
    """Looping through both masks and applying conditions directly on raster grid."""
    _, mask1, mask2 = generate_data()

    outer = []
    for m1 in [1, 2, 3, 4, 5]:
        inner = []
        for m2 in [6, 7, 8, 9, 10]:
            mask = (mask1 == m1) & (mask2 == m2)
            masked_data = da.where(mask)
            _result = getattr(masked_data, agg)()
            inner.append(_result)
        inner = xr.concat(inner, dim="label2")
        outer.append(inner)
    outer = xr.concat(outer, dim="label1")
    outer = outer.assign_coords(dict(label1=[1, 2, 3, 4, 5], label2=[6, 7, 8, 9, 10]))
    return outer

def flox_method(da, agg="sum"):
    """Using flox for same problem"""
    return flox.xarray.xarray_reduce(da, "label1", "label2", func=agg)

Execution:

da, *_ = generate_data()
bf_result = brute_force_method(da)
flox_result = flox_method(da)

print(flox_result - bf_result)
print((flox_result - bf_result)/bf_result * 100)

Differences are on the order of 1E-5.

Thoughts:

dcherian commented 1 week ago

I can't reproduce on numpy 2 even with xr.set_options(use_bottleneck=False, use_numbagg=False)

<xarray.DataArray 'data' (label1: 5, label2: 5)> Size: 200B
array([[ 2.84217094e-14,  1.42108547e-13,  2.27373675e-13,
         8.52651283e-14,  0.00000000e+00],
       [-1.98951966e-13,  0.00000000e+00, -2.84217094e-14,
        -2.84217094e-14,  1.13686838e-13],
       [-2.84217094e-14,  1.13686838e-13, -5.68434189e-14,
        -5.68434189e-14,  2.84217094e-14],
       [ 1.42108547e-13,  0.00000000e+00, -2.84217094e-14,
        -5.68434189e-14, -8.52651283e-14],
       [-8.52651283e-14, -1.42108547e-13,  2.84217094e-14,
         2.84217094e-14,  2.84217094e-14]])
Coordinates:
  * label1   (label1) int64 40B 1 2 3 4 5
  * label2   (label2) int64 40B 6 7 8 9 10

Also note that you can now use Xarray directly :) https://xarray.dev/blog/multiple-groupers
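
For reference, here's a minimal sketch of that multiple-groupers syntax, assuming a recent xarray release that ships xarray.groupers:

import xarray as xr
from xarray.groupers import UniqueGrouper

# group by both integer-mask coordinates at once; in spirit this mirrors
# flox.xarray.xarray_reduce(da, "label1", "label2", func="sum")
da, *_ = generate_data()
xr_result = da.groupby(label1=UniqueGrouper(), label2=UniqueGrouper()).sum()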

riley-brady commented 1 week ago

Isn't this reproducing it since you have roundoff error there rather than all zeroes?

Also note that you can now use Xarray directly :) https://xarray.dev/blog/multiple-groupers

Incredible, thank you!

EDIT: Some quick feedback. I'm finding that with xarray groupers I miss the expected_groups kwarg from flox. If I just use a UniqueGrouper() for categorical grouping across the unique integers on my grid, it crashes my cluster. That's what brought me to flox to begin with: using .groupby() on an integer mask grid kills the cluster, since it's likely running something like np.unique() under the hood. Since I can pre-determine the expected groups with flox, it runs substantially faster. Is this something I should raise over at xarray? I would love to have the direct xarray API for ease of use, and I assume weighted aggregations are easier there since I can access .weighted().

dcherian commented 1 week ago

OK, I was worried by your statement "Differences are on the order of 1E-5", which is too large. This level of floating-point inaccuracy is hard to fix, but I can take a look. numpy's sum can use pairwise summation, but the vectorized grouped sum uses bincount, which isn't as fancy, so there will be unavoidable precision differences.
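
A rough illustration of the difference, using a plain bincount sum as a stand-in for the grouped code path (a sketch, not flox's exact implementation):

import numpy as np

rng = np.random.default_rng(0)
x = rng.random(1_000_000)
labels = np.zeros(x.size, dtype=np.intp)  # everything in one group

pairwise = x.sum()                           # numpy's pairwise summation
grouped = np.bincount(labels, weights=x)[0]  # naive left-to-right accumulation

# the two totals generally differ by a tiny amount rather than matching exactly
print(pairwise - grouped)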

dcherian commented 1 week ago

just use a UniqueGrouper() for categorical grouping across the unique integers on my grid,

How many unique groups do you have? Can you share a dummy example?

since it's likely running something like np.unique() under the hood.

Yes, it is. I've been meaning to port expected_groups to UniqueGrouper, but haven't had the time to do so, and won't get to it for a while.

riley-brady commented 1 week ago

OK, I was worried by your statement "Differences are on the order of 1E-5", which is too large. This level of floating-point inaccuracy is hard to fix, but I can take a look. numpy's sum can use pairwise summation, but the vectorized grouped sum uses bincount, which isn't as fancy, so there will be unavoidable precision differences.

That makes sense. Just tried at scale with my full problem size (1km global grid with 1,134 unique groups across three levels of aggregation) and am seeing maximum deviation of 1E-4%.

How many unique groups do you have? Can you share a dummy example?

As in above... a ton. I can't even run np.unique() on a single mask with just 18 unique groups. I had tried this a few weeks ago when testing the grid generation.

Here's an example with just a single mask containing 18 unique groups for aggregation. I have a 50-worker dask cluster with 8 CPUs and 32 GB RAM per worker, and the np.unique() call (or the implicit UniqueGrouper) totally freezes it up. The equivalent flox call finishes in under 2 seconds.

import flox.xarray
import xarray as xr
import numpy as np
import dask.array as da

np.random.seed(123)

# Simulating 1km global grid
lat = np.linspace(-89.1, 89.1, 21384)
lon = np.linspace(-180, 180, 43200)

# Simulating data we'll be aggregating
data = da.random.random((lat.size, lon.size), chunks=(3600, 3600))
data = xr.DataArray(data, dims=['lat', 'lon'], coords={'lat': lat, 'lon': lon})

# Simulating 18 unique groups on the grid to aggregate over
integer_mask = da.random.choice(np.arange(1, 19), size=(lat.size, lon.size), chunks=(3600, 3600))
integer_mask = xr.DataArray(integer_mask, dims=['lat', 'lon'], coords={'lat': lat, 'lon': lon})

# Add as coordinate
data = data.assign_coords(dict(label1=integer_mask))

# Try with groupby (usually will spike scheduler memory, crash cluster, etc.). Haven't done a lot
# of looking at what's going on to wreck the cluster, just get impatient and give up.
# gb = data.groupby("label1")

# Versus, with expected groups. Runs extremely quickly to set up graph + execute.
res = flox.xarray.xarray_reduce(data, "label1", func="mean", skipna=True, expected_groups=np.arange(1, 19))

dcherian commented 1 week ago

pd.unique is usually a great substitute. You can also do it lazily with dask.array.


Re your issue: yes, Xarray doesn't support grouping by a dask array yet; it will eagerly compute it and then find the uniques. I need to port over that part of flox too :) But good to see that such features are useful!
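
For example, a sketch of both options, where numpy_mask is a placeholder for an in-memory mask and integer_mask is the dask-backed mask from the snippet above:

import pandas as pd
import dask.array as dsa

# for a NumPy-backed mask, pd.unique avoids the full sort that np.unique does
# groups = pd.unique(numpy_mask.ravel())

# for the dask-backed mask, find the uniques blockwise and only materialize
# the small result
groups = dsa.unique(integer_mask.data).compute()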

dcherian commented 1 week ago

can you add your nice example to https://github.com/pydata/xarray/issues/2852 please?

dcherian commented 1 week ago

and am seeing maximum deviation of 1E-4%.

Wild, can you confirm that this is unchanged on numpy 2, please?

riley-brady commented 1 week ago

Thanks for the tip on pd.unique! Yes... the expected_groups and everything with xarray_reduce is fantastic. Thank you! Next I'll be working on setting up some weighted aggregations (e.g. weighted mean) via the custom aggregation API in flox.

can you add your nice example to https://github.com/pydata/xarray/issues/2852 please?

Done!

Wild, can you confirm that this is unchanged on numpy 2, please?

Just ran this with numpy==2.0.2 and have the same result.

We can actually reproduce it with my MVE above.

# Brute force: mask out each group in turn and take the mean
results = []
for int_mask in np.arange(1, 19):
    masked_data = data.where(integer_mask == int_mask)
    res = masked_data.mean()
    results.append(res)

results = xr.concat(results, dim='int_val')
results = results.assign_coords(dict(int_val=np.arange(1, 19)))
results = results.compute()

# Compare to flox
b = flox.xarray.xarray_reduce(data, "label1", func="mean", skipna=True, expected_groups=np.arange(1, 19))
b = b.compute()

My "brute force" solution looks like:

array([0.49997059, 0.50000048, 0.50000736, 0.50009579, 0.49996396,
       0.50001181, 0.49997802, 0.49999938, 0.49992536, 0.49997482,
       0.49998581, 0.50000708, 0.50005436, 0.50001996, 0.50004166,
       0.5000194 , 0.50000429, 0.50000408])

The flox solution looks like:

array([0.49997059, 0.50000048, 0.50000736, 0.50009579, 0.49996396,
       0.50001181, 0.49997802, 0.49999938, 0.49992536, 0.49997482,
       0.49998581, 0.50000708, 0.50005436, 0.50001996, 0.50004166,
       0.5000194 , 0.50000429, 0.50000408])

For a % error of:

array([ 7.77201837e-14,  1.99839954e-13, -1.33224803e-13,  2.88602698e-13,
       -3.33090918e-13, -1.55427554e-13,  5.55135915e-14, -4.99600978e-13,
        1.88766094e-13,  4.10803207e-13, -3.88589086e-13,  3.77470481e-13,
       -1.11010234e-13, -2.22035741e-13,  1.11013052e-13, -6.66107975e-14,
        4.21881132e-13,  3.55268471e-13])

EDIT: Original answer had an error in my code. So here we're at very reasonable roundoff error.

dcherian commented 1 week ago

Next I'll be working on setting up some weighted aggregations (e.g. weighted mean) via the custom aggregation API in flox.

This probably won't fit well. It may be better to just do (weights * ds).groupby(...).agg() / weights.groupby(...).sum() or something.
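
For example, a sketch of that pattern using the flox call from the earlier example; the cos-latitude weights here are only an assumption for illustration:

import numpy as np
import xarray as xr
import flox.xarray

# hypothetical area weights, broadcast against the data grid
weights = xr.ones_like(data) * np.cos(np.deg2rad(data.lat))
# mask weights where data is NaN so numerator and denominator stay consistent
weights = weights.where(data.notnull())

expected = np.arange(1, 19)
num = flox.xarray.xarray_reduce(data * weights, "label1", func="sum",
                                skipna=True, expected_groups=expected)
den = flox.xarray.xarray_reduce(weights, "label1", func="sum",
                                skipna=True, expected_groups=expected)
weighted_mean = num / den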

Original answer had an error in my code. So here we're at very reasonable roundoff error.

OK great! I'm not sure we can do much better.