pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

[FEATURE]: Add a replace method #6377

Open Huite opened 2 years ago

Huite commented 2 years ago

Is your feature request related to a problem?

If I have a DataArray of values:

import numpy as np
import pandas as pd
import xarray as xr

da = xr.DataArray([0, 1, 2, 3, 4, 5])

And I'd like to replace to_replace=[1, 3, 5] with value=[10, 30, 50], there is no method da.replace(to_replace, value) to do this.

There's no easy way to do this like there is in pandas (https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.replace.html).
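
For comparison, a minimal example of the pandas API linked above, applied to a Series:

print(pd.Series([0, 1, 2, 3, 4, 5]).replace([1, 3, 5], [10, 30, 50]))
0     0
1    10
2     2
3    30
4     4
5    50
dtype: int64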

(Apologies if I've missed related issues, searching for "replace" gives many hits as the word is obviously used quite often.)

Describe the solution you'd like

da = xr.DataArray([0, 1, 2, 3, 4, 5])
replaced = da.replace([1, 3, 5], [10, 30, 50])
print(replaced)
<xarray.DataArray (dim_0: 6)>
array([ 0, 10,  2, 30,  4, 50])
Dimensions without coordinates: dim_0

I've had a try at a relatively efficient implementation below. I'm wondering whether it's a worthwhile addition to xarray?

Describe alternatives you've considered

Ignoring issues such as dealing with NaNs, chunks, etc., a simple dict lookup:

def dict_replace(da, to_replace, value):
    d = dict(zip(to_replace, value))
    out = np.vectorize(lambda x: d.get(x, x))(da.values)
    return da.copy(data=out)

Alternatively, leveraging pandas:

def pandas_replace(da, to_replace, value):
    df = pd.DataFrame()
    df["values"] = da.values.ravel()
    df["values"].replace(to_replace, value, inplace=True)
    return da.copy(data=df["values"].values.reshape(da.shape))

But I also tried my hand at a custom implementation, letting np.unique do the heavy lifting:

def custom_replace(da, to_replace, value):
    # Use np.unique to create an inverse index
    flat = da.values.ravel()
    uniques, index = np.unique(flat, return_inverse=True)    
    replaceable = np.isin(flat, to_replace)

    # Create a replacement array in which there is a 1:1 relation between
    # uniques and the replacement values, so that we can use the inverse index
    # to select replacement values. 
    valid = np.isin(to_replace, uniques, assume_unique=True)
    # Remove to_replace values that are not present in da. If no overlap
    # exists between to_replace and the values in da, just return a copy.
    if not valid.any():
        return da.copy()
    to_replace = to_replace[valid]
    value = value[valid]

    replacement = np.zeros_like(uniques)
    replacement[np.searchsorted(uniques, to_replace)] = value

    out = flat.copy()
    out[replaceable] = replacement[index[replaceable]]
    return da.copy(data=out.reshape(da.shape))

Such an approach seems like it's consistently the fastest:

da = xr.DataArray(np.random.randint(0, 100, 100_000))
to_replace = np.random.choice(np.arange(100), 10, replace=False)
value = to_replace * 200

test1 = custom_replace(da, to_replace, value)
test2 = pandas_replace(da, to_replace, value)
test3 = dict_replace(da, to_replace, value)

assert test1.equals(test2)
assert test1.equals(test3)

# 6.93 ms ± 295 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit custom_replace(da, to_replace, value) 

# 9.37 ms ± 212 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit pandas_replace(da, to_replace, value) 

# 26.8 ms ± 1.59 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit dict_replace(da, to_replace, value) 

With the advantage growing as the number of values involved increases:

da = xr.DataArray(np.random.randint(0, 10_000, 100_000))
to_replace = np.random.choice(np.arange(10_000), 10_000, replace=False)
value = to_replace * 200

test1 = custom_replace(da, to_replace, value)
test2 = pandas_replace(da, to_replace, value)
test3 = dict_replace(da, to_replace, value)

assert test1.equals(test2)
assert test1.equals(test3)

# 21.6 ms ± 990 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit custom_replace(da, to_replace, value)
# 3.12 s ± 574 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit pandas_replace(da, to_replace, value)
# 42.7 ms ± 1.47 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit dict_replace(da, to_replace, value)

In my real-life example, a DataArray of approx. 110,000 elements with 60,000 values to replace, the custom one takes 33 ms, the dict one takes 135 ms, while pandas takes 26 s (!).

Additional context

In all cases, we need to deal with NaNs, check the input, etc.:

from typing import Any

def replace(da: xr.DataArray, to_replace: Any, value: Any):
    from xarray.core.utils import is_scalar

    if is_scalar(to_replace):
        if not is_scalar(value):
            raise TypeError("if to_replace is scalar, then value must be a scalar")
        if np.isnan(to_replace):
            return da.fillna(value) 
        else:
            return da.where(da != to_replace, other=value)
    else:
        to_replace = np.asarray(to_replace)
        if to_replace.ndim != 1:
            raise ValueError("to_replace must be 1D or scalar")
        if is_scalar(value):
            value = np.full_like(to_replace, value)
        else:
            value = np.asarray(value)
            if to_replace.shape != value.shape:
                raise ValueError(
                    f"Replacement arrays must match in shape. "
                    f"Expecting {to_replace.shape} got {value.shape} "
                )

    _, counts = np.unique(to_replace, return_counts=True)
    if (counts > 1).any():
        raise ValueError("to_replace contains duplicates")

    # Replace NaN values separately, as they will show up as separate values
    # from numpy.unique.
    isnan = np.isnan(to_replace)
    if isnan.any():
        # Since duplicates are disallowed, at most one entry is NaN.
        i = np.nonzero(isnan)[0][0]
        da = da.fillna(value[i])

    # Use np.unique to create an inverse index
    flat = da.values.ravel()
    uniques, index = np.unique(flat, return_inverse=True)    
    replaceable = np.isin(flat, to_replace)

    # Create a replacement array in which there is a 1:1 relation between
    # uniques and the replacement values, so that we can use the inverse index
    # to select replacement values. 
    valid = np.isin(to_replace, uniques, assume_unique=True)
    # Remove to_replace values that are not present in da. If no overlap
    # exists between to_replace and the values in da, just return a copy.
    if not valid.any():
        return da.copy()
    to_replace = to_replace[valid]
    value = value[valid]

    replacement = np.zeros_like(uniques)
    replacement[np.searchsorted(uniques, to_replace)] = value

    out = flat.copy()
    out[replaceable] = replacement[index[replaceable]]
    return da.copy(data=out.reshape(da.shape))

I think it should be easy to make this work with e.g. apply_ufunc by letting the core operate on the numpy arrays (rough sketch below). The primary issue is whether the values can be sorted; when they cannot, the dict lookup might be an okay fallback? I've had a peek at the pandas implementation, but didn't become much wiser.
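
For example, a rough sketch of how it could be factored into a numpy-only kernel wrapped with apply_ufunc (the helper names are just for illustration, and the NaN/input handling above is omitted):

def _replace_numpy(arr, to_replace, value):
    # numpy-only core: the same np.unique trick as custom_replace
    # above, operating on a bare ndarray so it can run per dask block.
    flat = arr.ravel()
    uniques, index = np.unique(flat, return_inverse=True)
    replaceable = np.isin(flat, to_replace)
    valid = np.isin(to_replace, uniques, assume_unique=True)
    if not valid.any():
        return arr.copy()
    replacement = np.zeros_like(uniques)
    replacement[np.searchsorted(uniques, to_replace[valid])] = value[valid]
    out = flat.copy()
    out[replaceable] = replacement[index[replaceable]]
    return out.reshape(arr.shape)

def replace_values(da, to_replace, value):
    # The mapping is elementwise, so blockwise application over
    # (dask) chunks is safe.
    return xr.apply_ufunc(
        _replace_numpy,
        da,
        kwargs={"to_replace": np.asarray(to_replace), "value": np.asarray(value)},
        dask="parallelized",
        output_dtypes=[da.dtype],
    )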

Anyway, for your consideration! I'd be happy to submit a PR.

max-sixty commented 2 years ago

I agree this would be useful, and I've had to do similar things. It's the sort of area where pandas is stronger than xarray.

We might want a more specific name than replace; something that conveys it's replacing values? Particularly if the method is on a Dataset as well as a DataArray.

@Huite thanks for the great proposal. Did you look at np.select? I think that might be faster than these and require less code.

Huite commented 2 years ago

Yeah, I think maybe replace_values is a better name; "search and replace values" is maybe how you'd describe it colloquially? remap is an option too, but I think many users won't have the right association with it (if they're coming from a less technical background).

I don't think you'd want to do this with np.select. If I understand correctly, you'd have to broadcast over the number of values to replace. That works okay with a small number of replacement values, but not with 10,000 like in my example above (but my understanding might be lacking); see the sketch below.
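
To illustrate, a minimal (not benchmarked) np.select version might look like this; it materializes one full-size boolean mask per replacement value:

def select_replace(da, to_replace, value):
    # One boolean mask per value to replace; all masks exist at
    # once, so memory and time scale with len(to_replace), which is
    # what makes this impractical for thousands of replacement values.
    conditions = [da.values == v for v in to_replace]
    out = np.select(conditions, value, default=da.values)
    return da.copy(data=out)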

Having said that, there is a faster and much cleaner implementation using np.searchsorted on da instead.

def custom_replace2(da, to_replace, value):
    flat = da.values.ravel()

    # Sort to_replace so we can binary-search it, then find where each
    # element of the data would be inserted into the sorted order.
    sorter = np.argsort(to_replace)
    insertion = np.searchsorted(to_replace, flat, sorter=sorter)
    # Clip to guard against elements larger than to_replace.max(),
    # whose insertion point is one past the end of the array.
    indices = np.take(sorter, insertion, mode="clip")
    # Only exact matches should be replaced.
    replaceable = (to_replace[indices] == flat)

    out = flat.copy()
    out[replaceable] = value[indices[replaceable]]
    return da.copy(data=out.reshape(da.shape))

# For the small example: 4.1 ms ± 144 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# For the larger example: 14.4 ms ± 592 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit custom_replace2(da, to_replace, value)

This is equivalent to the implementation of remap in numpy-indexed (which is MIT-licensed): https://github.com/EelcoHoogendoorn/Numpy_arraysetops_EP

The key trick is the same, relying on sorting.

See e.g. also: https://stackoverflow.com/questions/16992713/translate-every-element-in-numpy-array-according-to-key

dcherian commented 2 years ago

See also #5048 though the discussion here is more thorough.

max-sixty commented 2 years ago

Nice find, @dcherian.

So it sounds like there's consensus around something like replace_data / replace_values / update_values. If you'd still be up for putting together a PR, I think that would be very welcome.

You're right about np.select, @Huite. The np.searchsorted solution looks v clever!

Jeitan commented 2 years ago

Thanks @dcherian for linking the other issue, because that led me here. I'm all for this! Though I would like to add one consideration: doing this replacement in a coordinate, not just the data (parts of the suggested code, like returning da.copy(data=out.reshape(da.shape)), won't work for that). Once they are accessed, coordinates work very much like the data part, so hopefully making this general shouldn't be too hard?

Huite commented 2 years ago

@Jeitan

The coordinate is a DataArray as well, so the following would work:

# Example DataArray
da = xr.DataArray(np.ones((3, 3)), {"y": [50.0, 60.0, 70.0], "x": [1.0, 2.0, 3.0]}, ("y", "x"))

# Replace 50.0 and 60.0 by 5.0 and 6.0 in the y coordinate
da["y"] = da["y"].replace_values([50.0, 60.0], [5.0, 6.0])

Your example in the other issue mentions one of the ways you'd replace in pandas, but for a dataframe. With a dataframe, there's quite some flexibility:

df.replace({0: 10, 1: 100})
df.replace({'A': 0, 'B': 5}, 100)
df.replace({'A': {0: 100, 4: 400}})

I'd say the xarray counterpart of a DataFrame is a Dataset; the counterpart of a DataArray is a Series. Replacing the coordinates in a DataArray is akin to replacing the values of the index of a Series, which is apparently possible with series.rename(index={from: to}).

Other thoughts: some complexity comes in when implementing a replace_values method for a Dataset. I also think the pandas replace method signature is too complicated (scalars, lists, dicts, dicts of dicts, probably more?) and the docstring is quite extensive (https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.replace.html)

I think the question is what the signature should be. You could compare to reindex (https://xarray.pydata.org/en/stable/generated/xarray.Dataset.reindex.html) and have a "replacer" argument:

da = da.replace({"y": ([50.0, 60.0], [5.0, 6.0])})

da["y"] = da["y"].replace([50.0, 60.0], [5.0, 6.0])

The first one would also work for Datasets, but I personally prefer the second one for its simplicity (and it is maybe closer to .where: https://xarray.pydata.org/en/stable/generated/xarray.DataArray.where.html). A sketch of how the first form could reduce to the second is below.
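
For concreteness, a hypothetical sketch of how the dict-based form could be implemented in terms of the per-variable method (replace_values is the method proposed above; it would apply to coordinates and data variables alike):

def dataset_replace(ds, replacer):
    # replacer maps a variable or coordinate name to a
    # (to_replace, value) pair, e.g. {"y": ([50.0, 60.0], [5.0, 6.0])}.
    out = ds.copy()
    for name, (to_replace, value) in replacer.items():
        out[name] = out[name].replace_values(to_replace, value)
    return out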

Jeitan commented 2 years ago

@Huite Indeed, you are right that working with a coordinate is easy if it works for DataArrays ... this is a good example of my pandas-oriented brain not quite being used to xarray just yet (though I do love it).

Regarding signature options for a Dataset ... given the two examples you state, I also personally prefer the look of the second one. However, the first one can be extremely useful for more complicated replacement needs because the input dict can be assembled programmatically prior to the replace call, for doing replaces in several subset DataArrays. I think the second version would require looping of some sort, or multiple calls at the very least. For me, in my context of renaming on coordinates (the index or columns in a DataFrame context), I often have to modify many things in both axes, which I do using one dictionary.

I suppose it's a matter of preference and of ease of implementation ... since I'm not the one doing the coding, I shall definitely defer to others on the latter point!

RichardScottOZ commented 1 year ago

Thanks @Huite