data-apis / array-api-compat

Compatibility layer for NumPy to support the Array API
https://data-apis.org/array-api-compat/
MIT License

Helper functionality to work around how different libraries handle copies vs. mutation and/or views? #146

Open mdhaber opened 1 month ago

mdhaber commented 1 month ago

Follow-up to https://github.com/data-apis/array-api-compat/issues/144#issuecomment-2161491665

Sometimes, to make existing code compatible with backends that are not fully standard-compliant, we would need to create copies where the original code would not.

For example, as a workaround for gh-144 (PyTorch doesn't support negative slice steps), we could (sometimes) make the replacement:

from array_api_compat import torch as xp
x = xp.arange(10)
# x[::-1]  # ValueError: step must be greater than zero
xp.flip(x)

As a workaround for JAX not supporting mutation, we could sometimes make replacements like:

import jax.numpy as jnp
# x is a JAX array, i is a boolean mask with the same shape
# x[i] = 0  # fails: JAX arrays are immutable
x = jnp.where(i, 0, x)
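For a boolean mask, the in-place assignment and the `where` replacement compute the same values; the only difference is that `where` builds a new array instead of modifying `x`. A minimal check, using NumPy as a stand-in for a library that supports both forms:

```python
import numpy as np

x = np.arange(6.0)
i = x > 3  # boolean mask with the same shape as x

# Functional form: returns a new array and leaves x untouched
y = np.where(i, 0, x)

# In-place form: mutates x itself
x[i] = 0

print(np.array_equal(x, y))  # True: same values either way
```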

However, making substitutions like this could decrease performance for array types that do support the desired operation (returning a view or mutating the original, in these cases).

Functions that perform the desired operation when possible and the substitute otherwise (e.g. https://github.com/scipy/scipy/pull/20085#issuecomment-2119307509) have been proposed. Do such things belong in array_api_compat?

asmeurer commented 1 month ago

Regarding flip(x) vs. x[::-1], note that np.flip does return a view. Torch only doesn't because it doesn't support views with negative strides at all. So just using flip is portable in the manner you suggest. It just isn't as readable (though even that could be argued), but there's no real way around that since we don't want to wrap array/tensor objects in this library. More arbitrary slices would also be more cumbersome to write, and we can definitely build a helper around that if it is needed.
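This is easy to check in NumPy: `np.flip` returns a view, and a negative-step slice can be rewritten as a flip followed by a positive-step slice, which PyTorch accepts. `reversed_step` below is a hypothetical helper (not part of array_api_compat) sketched under the assumption that only full-length slices of the form `x[::-step]` along the first axis are needed.

```python
import numpy as np


def reversed_step(x, step, xp=np):
    """Hypothetical helper: equivalent to x[::-step] for step > 0 (axis 0 only).

    Uses only flip and a positive-step slice, so it also works with
    libraries that reject negative slice steps.
    """
    if step <= 0:
        raise ValueError("step must be positive")
    return xp.flip(x, axis=0)[::step]


x = np.arange(10)
print(np.shares_memory(np.flip(x), x))               # True: np.flip returns a view
print(np.array_equal(reversed_step(x, 1), x[::-1]))  # True
print(np.array_equal(reversed_step(x, 3), x[::-3]))  # True
```

General slices with start and stop bounds need more index arithmetic, which is the extra cumbersomeness mentioned above.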

What would helper APIs for "mutate if you can" look like?

You'd have to be really careful using any API like that. A mutation and a copy are very different things and you'd need to make sure you write code that works correctly in both instances. I guess the way to handle it is to never actually rely on mutation semantically. Rather, mutation should just be treated as an implementation detail for performance. In other words, write code that treats everything as immutable but using functions that can mutate or "copy on write" when possible.
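One way to follow that rule, sketched under the assumption that a function first makes a private copy of its input and then hands it to a helper that may or may not mutate. `set_zero_where` and `clip_negative` are hypothetical names for illustration; the key discipline is that the caller always rebinds to the returned array and never relies on the argument having changed.

```python
import numpy as np


def set_zero_where(x, mask):
    """Hypothetical "mutate if you can" helper.

    With NumPy it modifies x in place; a JAX-style backend would return a
    new array instead. Either way, callers must use the return value.
    """
    x[mask] = 0
    return x


def clip_negative(a):
    # Private copy: the caller's array is never mutated, whether or not
    # set_zero_where mutates its argument.
    out = np.array(a, copy=True)
    out = set_zero_where(out, out < 0)  # always rebind to the returned array
    return out


a = np.array([-2.0, 3.0, -1.0, 4.0])
b = clip_negative(a)
print(a.tolist())  # [-2.0, 3.0, -1.0, 4.0]  (input unchanged)
print(b.tolist())  # [0.0, 3.0, 0.0, 4.0]
```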

JAX effectively does this internally, where non-aliased arrays are free to be mutated. I wonder if this request effectively amounts to "rewrite JAX on top of any array API library". If so, it might be very difficult or even impossible to achieve. For instance, even in NumPy we don't currently have the ability to fully track aliasing (views are only tracked in one direction).
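The one-directional tracking shows up in NumPy's `.base` attribute: a view records the array it was made from, but the original array holds no reference to its views. So, looking only at the base array, NumPy cannot tell whether mutating it would be visible through some other array.

```python
import numpy as np

a = np.arange(6)
v = a[::2]  # a view of a

print(v.base is a)             # True: the view knows its base
print(a.base is None)          # True: a owns its data and has no record of v
print(np.shares_memory(a, v))  # True, computed from the memory layout, not from a list of views
```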

CC @rgommers

mdhaber commented 1 month ago

> Regarding flip(x) vs. x[::-1], note that np.flip does return a view.

Ah, I don't think I noticed that. Good to know.

> Torch only doesn't because it doesn't support views with negative strides at all....

Isn't that problematic for the same reasons you mentioned? "You'd have to be really careful using any API like that."

Sure, we developers would need to be careful. In SciPy, we wouldn't pass the dangers on to the user. That is, the function would always return arrays that don't share memory with the inputs, regardless of whether we are doing mutations or copies internally as part of the calculation. The helper would only be used in circumstances where either one is acceptable. The alternative is just that JAX is left without that functionality.

> What would helper APIs for "mutate if you can" look like?

See the comment linked above: https://github.com/scipy/scipy/pull/20085#issuecomment-2119307509

asmeurer commented 1 month ago

I would say flip returning a view vs. a copy isn't quite the same because it's not actually mutating anything. It's only problematic if you later do an operation that could mutate the aliased memory. If you never do, the semantics are identical.

x[i] = 0 is different, because it itself is a mutation. So you're doing something that already has different semantics depending on whether x's memory is referenced by other arrays or not.
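A small NumPy illustration of both points: a flipped view and a flipped copy are indistinguishable until something mutates the shared memory, and `x[i] = 0` is exactly such a mutation. The mask `x % 2 == 0` is an arbitrary choice for illustration.

```python
import numpy as np

x = np.arange(5)
r_view = np.flip(x)         # NumPy returns a view of x
r_copy = np.flip(x).copy()  # what a copy-only backend would effectively give

print(np.array_equal(r_view, r_copy))  # True: identical semantics so far

i = x % 2 == 0  # boolean mask
x[i] = 0        # mutation of memory that r_view aliases

print(r_view.tolist())  # [0, 3, 0, 1, 0]: the view sees the change
print(r_copy.tolist())  # [4, 3, 2, 1, 0]: the copy does not
```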

But you're right that this all comes down to the fact that the standard is agnostic about views vs. copies, so portable code always has to be written in a way that doesn't rely on mutation. See https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html (I pinged Ralf here because I know he has thought about that particular document a lot).

mdhaber commented 1 month ago

And he wrote the linked comment with the prototype helper.

asmeurer commented 1 month ago

So you're referring to this function?

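from __future__ import annotations

from types import ModuleType
from typing import Any

from array_api_compat import array_namespace, is_jax_namespace

Array = Any  # placeholder alias for any array type


def at_set(
    x: Array,
    idx: Array | int | slice,
    val: Array | int | float | complex,
    *,
    xp: ModuleType | None = None,
) -> Array:
    """Set x[idx] = val. May or may not mutate x; use only if no views are involved."""
    xp = array_namespace(x) if xp is None else xp
    if is_jax_namespace(xp):
        # JAX arrays are immutable: build an updated copy instead
        if hasattr(idx, 'dtype') and xp.isdtype(idx.dtype, 'bool'):
            x = xp.where(idx, val, x)  # take val where the mask is True
        else:
            x = x.at[idx].set(val)
    else:
        # Other libraries: mutate x in place
        x[idx] = val
    return x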

That seems fine (except I would make a few minor changes), as long as it's well documented that it might or might not actually mutate x.

mdhaber commented 1 month ago

Yes, and yeah, documenting that would be key.