pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

Rolling.argmin() and Rolling.argmax() over multiple dimensions does not work #6691

Open AlxLhrNc opened 2 years ago

AlxLhrNc commented 2 years ago

What is your issue?

I am working with xarray.core.rolling.DataArrayRolling objects, and for some obscure reason they raise the following error when the .argmin() and .argmax() methods are called:

# data_nc: a DataArray with 'lat' and 'lon' dimensions; n: the window length
window_size = {name: n for name in ['lat', 'lon']}
window = data_nc.rolling(window_size, center=True)
peak_min = window.argmin()
peak_max = window.argmax()

Traceback (most recent call last):
  Input In [48] in <cell line: 3>
  File ~\Installed_Programs\Anaconda3\envs\phd\lib\site-packages\xarray\core\rolling.py:122 in method
    return self._numpy_or_bottleneck_reduce(

  File ~\Installed_Programs\Anaconda3\envs\phd\lib\site-packages\xarray\core\rolling.py:541 in _numpy_or_bottleneck_reduce
    return self.reduce(array_agg_func, keep_attrs=keep_attrs, **kwargs)

  File ~\Installed_Programs\Anaconda3\envs\phd\lib\site-packages\xarray\core\rolling.py:428 in reduce
    result = windows.reduce(

  File ~\Installed_Programs\Anaconda3\envs\phd\lib\site-packages\xarray\core\dataarray.py:2696 in reduce
    var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, **kwargs)

  File ~\Installed_Programs\Anaconda3\envs\phd\lib\site-packages\xarray\core\variable.py:1804 in reduce
    data = func(self.data, axis=axis, **kwargs)

  File ~\Installed_Programs\Anaconda3\envs\phd\lib\site-packages\xarray\core\duck_array_ops.py:335 in f
    return func(values, axis=axis, **kwargs)

  File <__array_function__ internals>:180 in argmax

  File ~\Installed_Programs\Anaconda3\envs\phd\lib\site-packages\numpy\core\fromnumeric.py:1216 in argmax
    return _wrapfunc(a, 'argmax', axis=axis, out=out, **kwds)

  File ~\Installed_Programs\Anaconda3\envs\phd\lib\site-packages\numpy\core\fromnumeric.py:66 in _wrapfunc
    return _wrapit(obj, method, *args, **kwds)

  File ~\Installed_Programs\Anaconda3\envs\phd\lib\site-packages\numpy\core\fromnumeric.py:43 in _wrapit
    result = getattr(asarray(obj), method)(*args, **kwds)

TypeError: 'tuple' object cannot be interpreted as an integer

As far as I understand from the xarray documentation, these methods are supported, so they should work just like .min(), .max() and .mean(), which work fine: https://docs.xarray.dev/en/stable/generated/xarray.core.rolling.DataArrayRolling.argmax.html

Any insight into what I am doing wrong?

keewis commented 2 years ago

For reference, here's the MCVE:

import xarray as xr

ds = xr.tutorial.open_dataset("air_temperature")
n = 3

window_size = {name: n for name in ["lat", "lon"]}
window = ds.rolling(window_size, center=True)
# fails with the same TypeError as in the traceback above
peak_min = window.argmin()
peak_max = window.argmax()

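As far as I can tell, the reason is that np.argmin (like np.argmax) only accepts a single integer axis, not a tuple of axes, while the rolling reduction passes one axis per rolled dimension (the actual call is here).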
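A small NumPy-only illustration of that difference (the array here is made up for demonstration): reductions such as max accept a tuple of axes, but np.argmax does not.

```python
import numpy as np

a = np.arange(24).reshape(2, 3, 4)

# max happily reduces over several axes at once
maxima = a.max(axis=(1, 2))
print(maxima)  # [11 23]

# argmax only takes a single integer axis (or None), so a tuple fails
try:
    np.argmax(a, axis=(1, 2))
except TypeError as err:
    raised = True
    print("TypeError:", err)
else:
    raised = False
```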

We could try to work around this by wrapping argmin and argmax in a function that ravels / unravels the reduce dimensions.
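A minimal sketch of that ravel / unravel idea in plain NumPy, assuming a hypothetical helper named argmax_multi (not part of xarray or NumPy): move the reduced axes to the end, flatten them into one axis, take argmax there, then unravel the flat index back into one index per reduced axis.

```python
import numpy as np

def argmax_multi(values, axis):
    """Hypothetical helper: argmax over several axes at once.

    Returns a tuple with one index array per reduced axis, in the
    order the axes were given.
    """
    values = np.asarray(values)
    if isinstance(axis, int):
        axis = (axis,)
    # Normalize negative axes so the membership test below works
    axis = tuple(ax % values.ndim for ax in axis)
    kept = [ax for ax in range(values.ndim) if ax not in axis]
    # Put the kept axes first and the reduced axes last
    moved = np.transpose(values, kept + list(axis))
    reduced_shape = moved.shape[len(kept):]
    # Ravel: merge all reduced axes into a single trailing axis
    flat = moved.reshape(moved.shape[:len(kept)] + (-1,))
    flat_idx = np.argmax(flat, axis=-1)
    # Unravel: turn the flat index back into per-axis indices
    return np.unravel_index(flat_idx, reduced_shape)

a = np.zeros((2, 3, 4))
a[0, 1, 2] = 5
a[1, 2, 0] = 7
rows, cols = argmax_multi(a, axis=(1, 2))
print(rows, cols)  # [1 2] [2 0]
```

An argmin variant would only swap np.argmax for np.argmin. Separately, recent xarray versions accept a list of dimensions in DataArray.argmax(dim=[...]) and return a dict of index arrays, which may offer another route when combined with the rolling object's construct() method.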

AlxLhrNc commented 2 years ago

Thank you for looking into this question. Following keewis' suggestion:

We could try to work around this by wrapping argmin and argmax in a function that ravels / unravels the reduce dimensions.

I had a look at the NumPy documentation and came up with a workaround:

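import numpy as np
import xarray as xr

# 5 x 6 x 7 array of ones with a single minimum and a single maximum
arr = xr.DataArray(np.ones((5, 6, 7)))
arr[1, 3, 2], arr[3, 1, 4] = 0, 2

print('Values:', arr.min().item(), arr.max().item())

# np.argwhere returns one row of n-D indices per matching element,
# so ties do not get mixed together as with np.concatenate(np.where(...))
min_pos = np.argwhere(arr.values == arr.values.min())
max_pos = np.argwhere(arr.values == arr.values.max())
print('Indexes on nD:', min_pos, max_pos)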

It is not perfect, I suppose, but it does the job for now.
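The workaround above locates the extremum over the whole array rather than inside each rolling window. A rough per-window sketch, assuming a plain 2-D NumPy array, square windows, only full windows (no edge padding, unlike center=True), and a hypothetical helper named rolling_argmax_2d, can combine numpy.lib.stride_tricks.sliding_window_view (NumPy >= 1.20) with the same ravel / unravel trick:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def rolling_argmax_2d(values, window):
    """Hypothetical helper: (row, col) offset of the maximum inside
    every full window x window block of a 2-D array."""
    # Shape: (rows - window + 1, cols - window + 1, window, window)
    windows = sliding_window_view(values, (window, window))
    # Ravel the two window axes into one so argmax sees a single axis
    flat = windows.reshape(windows.shape[:2] + (-1,))
    flat_idx = np.argmax(flat, axis=-1)
    # Unravel the flat index back into offsets within each window
    return np.unravel_index(flat_idx, (window, window))

values = np.zeros((4, 4))
values[1, 1] = 9  # single peak
rows, cols = rolling_argmax_2d(values, 3)
print(rows.tolist(), cols.tolist())  # [[1, 1], [0, 0]] [[1, 0], [1, 0]]
```

The offsets are relative to each window's top-left corner; adding the window's starting index would give positions in the original array.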