pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

DataArray.set_index can add new dimensions that are absent on the underlying array. #9278

Open max-sixty opened 1 month ago

max-sixty commented 1 month ago

What is your issue?

I know we've done lots of great work on indexes. But I worry that we've made some basic things confusing. (Or at least I'm confused, very possibly I'm being slow this morning?)

A colleague asked how to group by multiple coords. I told them to set_index(foo=coord_list).groupby("foo"). But that seems to work inconsistently.


Here it works great:

import numpy as np
import pandas as pd
import xarray as xr

da = xr.DataArray(
    np.array([1, 2, 3, 0, 2, np.nan]),
    dims="d",
    coords=dict(
        labels1=("d", np.array(["a", "b", "c", "c", "b", "a"])),
        labels2=("d", np.array(["x", "y", "z", "z", "y", "x"])),
    ),
)
da

<xarray.DataArray (d: 6)> Size: 48B
array([ 1.,  2.,  3.,  0.,  2., nan])
Coordinates:
    labels1  (d) <U1 24B 'a' 'b' 'c' 'c' 'b' 'a'
    labels2  (d) <U1 24B 'x' 'y' 'z' 'z' 'y' 'x'
Dimensions without coordinates: d

Then make the two coords into a multiindex along d and group by d — we successfully get a value for each of the three values on d:

da.set_index(d=['labels1', 'labels2']).groupby('d').mean()

<xarray.DataArray (d: 3)> Size: 24B
array([1. , 2. , 1.5])
Coordinates:
  * d        (d) object 24B ('a', 'x') ('b', 'y') ('c', 'z')

But here it doesn't work as expected:

da = xr.DataArray(
    np.array([1, 2, 3, 0, 2, np.nan]),
    dims="time",
    coords=dict(
        time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)),
        labels1=("time", np.array(["a", "b", "c", "c", "b", "a"])),
        labels2=("time", np.array(["x", "y", "z", "z", "y", "x"])),
    ),
)
da

<xarray.DataArray (time: 6)> Size: 48B
array([ 1.,  2.,  3.,  0.,  2., nan])
Coordinates:
  * time     (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30
    labels1  (time) <U1 24B 'a' 'b' 'c' 'c' 'b' 'a'
    labels2  (time) <U1 24B 'x' 'y' 'z' 'z' 'y' 'x'

reindexed = da.set_index(combined=['labels1', 'labels2'])
reindexed

<xarray.DataArray (time: 6)> Size: 48B
array([ 1.,  2.,  3.,  0.,  2., nan])
Coordinates:
  * time      (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30
  * combined  (combined) object 48B MultiIndex
  * labels1   (combined) <U1 24B 'a' 'b' 'c' 'c' 'b' 'a'
  * labels2   (combined) <U1 24B 'x' 'y' 'z' 'z' 'y' 'x'

Then we try grouping by combined, and we get a value for every value of combined and time?

reindexed.groupby('combined').mean()

<xarray.DataArray (time: 6, combined: 3)> Size: 144B
array([[ 1.,  1.,  1.],
       [ 2.,  2.,  2.],
       [ 3.,  3.,  3.],
       [ 0.,  0.,  0.],
       [ 2.,  2.,  2.],
       [nan, nan, nan]])
Coordinates:
  * time      (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30
  * combined  (combined) object 24B ('a', 'x') ('b', 'y') ('c', 'z')

I'm guessing the reason is that combined (combined) object 48B MultiIndex is treated as a dimension?
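One way to check that guess is to compare each coordinate's dimensions with the array's own. The helper below is a minimal sketch: coords_with_foreign_dims is a hypothetical name, not an xarray function, and it relies only on the public .dims and .coords attributes.

```python
import numpy as np
import xarray as xr


def coords_with_foreign_dims(obj):
    """Map coordinate name -> dims for coordinates that use a dimension
    the array itself does not have (hypothetical helper, not xarray API)."""
    return {
        name: coord.dims
        for name, coord in obj.coords.items()
        if not set(coord.dims) <= set(obj.dims)
    }


# A consistent array: every coordinate lives on the array's own dimension.
da = xr.DataArray(
    np.arange(6.0),
    dims="time",
    coords=dict(labels1=("time", list("abccba"))),
)
print(coords_with_foreign_dims(da))  # empty dict for a consistent array
```

Called on the reindexed array above, and assuming an xarray version that shows the behaviour described here, this would be expected to list combined, labels1 and labels2, since their repr shows them on (combined) while the array itself only has time.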


To the extent this sort of point is correct and it's not just me misunderstanding: I know we've done some really great work recently on moving xarray forward. But one of the main reasons I fell in love with xarray 10 years ago was how explicit the API was relative to pandas. And I worry that this sort of thing sets us back a bit.

dcherian commented 1 month ago

As an aside, the API isn't great but this works in flox (I think)

import flox.xarray

flox.xarray.xarray_reduce(da, "labels1", "labels2", func="mean")
dcherian commented 1 month ago

In your example

  1. I think you need to change the dimension time to combined on the array too. To me, it's not obvious why Xarray should do this automatically, but I'm also quite confused by set_index on a non-existing dimension. Seems wild.
  2. The bizarre broadcasting behaviour is a side-effect of concat being too permissive: https://github.com/pydata/xarray/issues/8778 https://github.com/pydata/xarray/issues/2145 which we should fix.
max-sixty commented 1 month ago

I think you need to change the dimension time to combined on the array too. To me, it's not obvious why Xarray should do this automatically, but I'm also quite confused by set_index on a non-existing dimension. Seems wild.

Yes, possibly we should raise an error on this?
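For illustration, a user-side guard of that kind could look roughly like the sketch below. checked_set_index is a hypothetical wrapper, not part of xarray, and says nothing about how xarray itself would implement such a check; it assumes only that DataArray.set_index accepts dimension names as keyword arguments, as in the examples above.

```python
import numpy as np
import xarray as xr


def checked_set_index(obj, **indexes):
    """Call obj.set_index, but refuse index names that are not existing
    dimensions of obj (hypothetical wrapper, not xarray API)."""
    missing = [dim for dim in indexes if dim not in obj.dims]
    if missing:
        raise ValueError(
            f"cannot set an index on {missing}: not in dimensions {obj.dims}"
        )
    return obj.set_index(**indexes)


da = xr.DataArray(
    np.array([1, 2, 3, 0, 2, np.nan]),
    dims="d",
    coords=dict(
        labels1=("d", np.array(["a", "b", "c", "c", "b", "a"])),
        labels2=("d", np.array(["x", "y", "z", "z", "y", "x"])),
    ),
)

# Existing dimension: passes straight through to set_index.
ok = checked_set_index(da, d=["labels1", "labels2"])

# Non-existing dimension: rejected before xarray is called.
try:
    checked_set_index(da, combined=["labels1", "labels2"])
except ValueError as err:
    print(err)
```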

Possibly our indexing work means we can now create multiple indexes on a dimension, so we want to be able to .set_index(combined=['labels1', 'labels2']) even though combined isn't a dimension. But the rest of the library hasn't caught up, so we get this incoherent groupby behavior?

Regardless, I'm not sure what the (combined) in * combined (combined) object 48B MultiIndex means given we don't have a combined dimension...

dcherian commented 1 month ago

I'm not sure what the (combined) in * combined (combined) object 48B MultiIndex means given we don't have a combined dimension

You've created a new variable named combined with dimension name combined and assigned an index to combined. Really what you want is a new variable combined with dimension name time. I don't know that there's an ergonomic way of doing that AND stacking at the same time. da.stack expects to stack dimensions.
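A less ergonomic way to get a combined variable that lives on time is to build it by hand. The sketch below swaps the MultiIndex for a plain string coordinate, which is a simplification and not what set_index does: it joins the two labels with np.char.add, attaches the result along time, and groups by it. The "_" separator is an arbitrary choice, and daily dates stand in for the month-end dates in the original example.

```python
import numpy as np
import pandas as pd
import xarray as xr

da = xr.DataArray(
    np.array([1, 2, 3, 0, 2, np.nan]),
    dims="time",
    coords=dict(
        # Daily dates here; the original example used month-end dates.
        time=("time", pd.date_range("2001-01-01", periods=6)),
        labels1=("time", np.array(["a", "b", "c", "c", "b", "a"])),
        labels2=("time", np.array(["x", "y", "z", "z", "y", "x"])),
    ),
)

# Join the two label arrays element-wise into strings such as "a_x".
combined = np.char.add(
    np.char.add(da["labels1"].values, "_"), da["labels2"].values
)

# Attach the joined labels along the existing "time" dimension, then group.
grouped = da.assign_coords(combined=("time", combined)).groupby("combined").mean()
print(grouped)
```

The result has a single combined dimension with three groups, and the means should match the first example (1.0, 2.0, 1.5), because mean skips the NaN by default.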

max-sixty commented 1 month ago

You've created a new variable named combined with dimension name combined and assigned an index to combined.

But there's no dimension named combined on the data array! In the repr of reindexed above, combined is not a dimension: the array lists <xarray.DataArray (time: 6)> Size: 48B as its dimensions. What's a good mental model for this?

dcherian commented 1 month ago

Ah, now I see. Yes, that's a bug.