Open mdhaber opened 4 months ago
Regarding flip(x) vs. x[::-1], note that np.flip does return a view. Torch only doesn't because it doesn't support views with negative strides at all. So just using flip is portable in the manner you suggest. It just isn't as readable (though even that could be argued), but there's no real way around that since we don't want to wrap array/tensor objects in this library. More arbitrary slices would also be more cumbersome to write, and we can definitely build a helper around that if it is needed.
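A quick way to check the view claim for NumPy is sketched below. It assumes only NumPy is installed; the variable names are illustrative, and nothing here says how other backends behave.

```python
import numpy as np

x = np.arange(5)

flipped = np.flip(x)  # NumPy builds this as a reversed slice, so it is a view
sliced = x[::-1]      # the spelling that is not portable to every backend

# Both results alias x's memory and use a negative stride.
shares_flip = np.shares_memory(x, flipped)
shares_slice = np.shares_memory(x, sliced)
same_strides = flipped.strides == sliced.strides
negative_stride = flipped.strides[0] < 0
```

On a backend without negative-stride views, the same flip call would have to hand back a copy, which is why portable code should not depend on either outcome.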
What would helper APIs for "mutate if you can" look like?
You'd have to be really careful using any API like that. A mutation and a copy are very different things and you'd need to make sure you write code that works correctly in both instances. I guess the way to handle it is to never actually rely on mutation semantically. Rather, mutation should just be treated as an implementation detail for performance. In other words, write code that treats everything as immutable, but use functions that can mutate or "copy on write" when possible.
JAX effectively does this internally, where non-aliased arrays are free to be mutated. I wonder if this request effectively amounts to "rewrite JAX on top of any array API library". If so, it might be very difficult or even impossible to achieve. For instance, even in NumPy we don't currently have the ability to fully track aliasing (views are only tracked in one direction).
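The one-directional tracking in NumPy can be seen with the `.base` attribute. This is a small sketch assuming plain NumPy arrays; the variable names are illustrative.

```python
import numpy as np

x = np.arange(6)  # x owns its memory
v = x[::-1]       # v is a view onto x

# The view records where its memory comes from ...
view_knows_owner = v.base is x
# ... but the owner has no base of its own and keeps no list of its views.
owner_has_no_base = x.base is None
```

Because `x` has no record of `v`, a helper that receives only `x` cannot tell from the array itself whether mutating it would be visible elsewhere.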
CC @rgommers
Regarding flip(x) vs. x[::-1], note that np.flip does return a view.
Ah, I don't think I noticed that. Good to know.
Torch only doesn't because it doesn't support views with negative strides at all....
Isn't that problematic for the same reasons you mentioned? "You'd have to be really careful using any API like that."
Sure, we developers would need to be careful. In SciPy, we wouldn't pass the dangers on to the user. That is, the function would always return arrays that don't share memory with the inputs, regardless of whether we are doing mutations or copies internally as part of the calculation. It would only be for circumstances in which either one is acceptable. The alternative is just that JAX is left without that functionality.
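One way to read that guarantee is sketched below, assuming NumPy for concreteness. The function `clip_negative_to_zero` is hypothetical and not a SciPy API; it only illustrates copying once at the boundary so that any later internal mutation stays invisible to the caller.

```python
import numpy as np

def clip_negative_to_zero(x):
    """Hypothetical public function: return x with negative entries set to 0."""
    out = np.array(x, copy=True)  # private buffer that never aliases the input
    out[out < 0] = 0              # internal mutation, purely an implementation detail
    return out

a = np.array([-1.0, 2.0, -3.0])
result = clip_negative_to_zero(a)
input_untouched = a.tolist() == [-1.0, 2.0, -3.0]
no_shared_memory = not np.shares_memory(a, result)
```

With an immutable backend the in-place step would be replaced by a copying update, and the caller would see the same result either way.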
What would helper APIs for "mutate if you can" look like?
See the comment linked above: https://github.com/scipy/scipy/pull/20085#issuecomment-2119307509
I would say flip returning a view vs. a copy isn't quite the same because it's not actually mutating anything. It's only problematic if you later do an operation that could mutate the aliased memory. If you never do, the semantics are identical. x[i] = 0 is different, because it itself is a mutation. So you're doing something that already has different semantics depending on whether x's memory is referenced by other arrays or not.
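The distinction can be made concrete with NumPy. In this sketch, `y_copy` stands in for what a backend that copies on flip would effectively return; that stand-in is an assumption for illustration, not a claim about any particular library.

```python
import numpy as np

x = np.arange(4)            # [0, 1, 2, 3]
y_view = np.flip(x)         # aliases x in NumPy
y_copy = np.flip(x).copy()  # stand-in for a backend that copies on flip

# Until something mutates x, the two results are indistinguishable.
identical_before = y_view.tolist() == y_copy.tolist()

x[0] = 100                  # the mutation is what exposes the difference

view_sees_change = y_view[-1] == 100
copy_unchanged = y_copy[-1] == 0
```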
But you're right that this all comes down to the fact that the standard is agnostic about views vs. copies, so portable code always has to be written in a way that doesn't rely on mutation. See https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html (I pinged Ralf here because I know he has thought about that particular document a lot).
And he wrote the linked comment with the prototype helper.
So you're referring to this function?
def at_set(
    x : Array,
    idx: Array | int | slice,
    val: Array | int | float | complex,
    *,
    xp: ModuleType | None = None,
) -> Array:
    """In-place update. Use only if no views are involved."""
    xp = array_namespace(x) if xp is None else xp
    if is_jax(xp):
        if xp.isdtype(idx.dtype, 'bool'):
            x = xp.where(idx, x, val)
        else:
            x = x.at[idx].set(val)
    else:
        x[idx] = val
    return x
That seems fine (except I would make a few minor changes), as long as it's well documented that it might or might not actually mutate x.
Yes, and yeah, documenting that would be key.
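For illustration, here is one possible variant of the prototype with its mutation behaviour documented. It is a sketch under stated assumptions, not the eventual helper: detecting an immutable array via `hasattr(x, "at")` is a stand-in for `is_jax(xp)`, `xp` is passed explicitly, and only the NumPy path is exercised here. Two details differ from the quoted prototype: for a boolean mask, `x[idx] = val` corresponds to `where(idx, val, x)` (take `val` where the mask is true), and the mask check is guarded because int and slice indices have no `dtype`.

```python
import numpy as np

def at_set(x, idx, val, *, xp):
    """Return an array equal to x with x[idx] set to val.

    May or may not mutate x. Callers must use the return value and must
    not rely on x itself having changed (or not changed).
    """
    # Assumption: arrays exposing `.at` (as JAX arrays do) cannot be mutated.
    if not hasattr(x, "at"):
        x[idx] = val
        return x
    # int and slice indices have no dtype, so look it up defensively.
    is_bool_mask = getattr(getattr(idx, "dtype", None), "kind", "") == "b"
    if is_bool_mask:
        return xp.where(idx, val, x)  # val where the mask is True, x elsewhere
    return x.at[idx].set(val)

x = np.zeros(4)
x = at_set(x, np.array([True, False, True, False]), 5.0, xp=np)
x = at_set(x, 1, 7.0, xp=np)
x = at_set(x, slice(3, 4), 1.0, xp=np)
```

Always rebinding the result (`x = at_set(...)`) is what lets the same calling code run whether the backend mutated or copied.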
Follow-up to https://github.com/data-apis/array-api-compat/issues/144#issuecomment-2161491665
Sometimes to make existing code compatible with backends that are not fully standard compliant, we would need to create copies where the original code would not.
For example, as a workaround for gh-144 (PyTorch doesn't support negative step), we could (sometimes) replace x[::-1] with flip(x). As a workaround for JAX not supporting mutation, we could sometimes replace in-place assignments such as x[i] = 0 with an expression that returns a new array.
However, making substitutions like this could decrease performance for array types that do support the desired operation (returning a view or mutating the original, in these cases).
Functions that perform the desired operation when possible and the substitute otherwise (e.g. https://github.com/scipy/scipy/pull/20085#issuecomment-2119307509) have been proposed. Do such things belong in array_api_compat?