@honnibal (of spaCy fame) left a really good comment on this topic on the Twitter announcement thread (https://twitter.com/honnibal/status/1295369359653842945):
Btw an API suggestion: I think out= is really unideal as an API. Functions should optionally accept one or more buffers of memory they may use if they help -- or not
The problem with out= is it mixes the efficiency consideration of "you're allowed to do this inplace" with the semantics of the function (the output must be in this array). This gives the user no way to say "Just do whatever's fastest". It's also not nearly general enough
I think reframing things around optionally reused buffers could be a good way to bridge the gap between mutation-based and non-mutation-based APIs. For example, most of JAX's APIs are based on non-mutating pure functions, but JAX has a notion of "donated" arguments (donate_argnums in jax.jit) whose buffers may be reused in place.
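For reference, a minimal sketch of that JAX feature (donate_argnums is part of jax.jit; whether the buffer actually gets reused depends on the backend):

```python
import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, donate_argnums=(0,))
def add_one(x):
    return x + 1

x = jnp.ones(5)
y = add_one(x)   # x's buffer is "donated" and may be reused for y; x should not be used afterwards
```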
Interestingly, operations like += in Python effectively automatically implement "buffer reuse" rather than "inplace" semantics. An operation like x += 1 may or may not reuse the x buffer, e.g., depending on whether x is a scalar or NumPy array. This is implemented at the language level: if __iadd__ doesn't exist, Python simply calls __add__ instead, which is how TensorFlow/JAX support inplace arithmetic even if they don't support mutation.
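A minimal illustration of that language-level fallback (Frozen is a made-up class, not from any array library):

```python
class Frozen:
    """No __iadd__ defined, so `x += 1` falls back to __add__ and rebinds x."""
    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        return Frozen(self.value + other)

x = Frozen(1)
before = id(x)
x += 1                   # calls __add__, then rebinds the name x
assert id(x) != before   # a new object; nothing was mutated in place
```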
Would it be too much API innovation to add a new buffer argument rather than supporting out? This would need a new argument name, unfortunately, because even if APIs that take out always return an output value, too, people may be discarding the return value and relying on the inplace semantics instead.
There is also the challenge of how to spell optional buffer reuse for indexing assignment. Perhaps something like x = x.at[idx].set(y, reuse_buffer=True), x.at[idx].set_inplace(y) or x.at[idx].iset(y)? It's pretty awkward, though.
Would it be too much API innovation to add a new buffer argument rather than supporting out?
Maybe not too much, it's only a single keyword after all. It seems to be an incomplete design though - in the absence of whole-program optimization, the semantics of buffer are unclear as far as I can tell. How does "you may use this buffer or you may not" result in predictable code for users of numpy? To make that concrete, how would you change
x = np.ones(5)
y = np.empty_like(x)
np.sin(x, out=y)
Maybe I'm missing something here, but buffer looks a bit like the optional "may overwrite" working buffers in LAPACK.
I think this out= decision is still the least important of all that needs to be decided here though. I won't lose sleep over removing out= if needed.
There is also the challenge of how to spell optional buffer reuse for indexing assignment. Perhaps something like x = x.at[idx].set(y, reuse_buffer=True), x.at[idx].set_inplace(y) or x.at[idx].iset(y)? It's pretty awkward, though.
Yep, that's the harder part. Disregarding out=, what do you think about option 4 or 5? That to me seems like a compromise where NumPy/PyTorch/CuPy users will need to do a manageable amount of code rewriting, while JAX/TensorFlow/MXNet need to add a wrapper layer (also manageable). And it avoids the problems with inconsistent copy/view returns.
Regarding your option 3 (no mutation at all), I'm afraid that'll be a reason for outright rejection of the whole standard by the community - it requires a whole change of thinking and a ton of code changes.
Your option 2 (make support for in-place operations optional) seems like the next-best thing. The most likely result will be that functionality that relies on in-place mutation won't support JAX for the time being, until the at[idx].set pattern becomes available more widely and we add it to the standard.
Would it be too much API innovation
A procedural thought: for this kind of addition of something that doesn't exist yet, we should probably have a separate status in the standard, like "provisional". Like a recommendation that, if one wants to add a feature like that, it must be spelled like X and behave like Y. That guarantees that projects experimenting with this because they'd like to improve on out= don't diverge. And it avoids mandating something that may turn out to be sub-optimal. I'm a little uncomfortable with adding new things unless it's quite clear that there'll be no hiccups during implementation.
To make that concrete, how would you change
x = np.ones(5)
y = np.empty_like(x)
np.sin(x, out=y)
Maybe I'm missing something here, but buffer looks a bit like the optional "may overwrite" working buffers in LAPACK.
This might become something like:
x = np.ones(5)
y = np.empty_like(x)
z = np.sin(x, buffer=y)
where the values that end up filling y are not well defined.
That said, I do think something like "may overwrite" is a better way to spell this than allowing filling into arbitrary buffers, closer in spirit to both working buffers in LAPACK and Python's own inplace arithmetic. As a user, you would write:
x = np.ones(5)
z = np.sin(x, may_overwrite=True)
The values of x after computing z would be undefined. In practice, the safe way to write this would thus be to reuse the variable x, i.e., x = np.sin(x, may_overwrite=True).
Inplace operations now become an optimization detail, rather than a requirement as would be the case for np.sin(x, out=x).
where the values that end up filling y are not well defined. ... The values of x after computing z would be undefined.
yes, that's the footgun I was worried about
In practice, the safe way to write this would thus be to reuse the variable x, i.e., x = np.sin(x, may_overwrite=True).
That's better, I can see the appeal of that. A boolean keyword may not generalize well for multiple inputs, which I think was a point @honnibal was making with "one or more buffers". It's also why out= takes a tuple. That said, I've never seen that actually used, so it may not be relevant.
So x = np.sin(x, may_overwrite=True) is likely an improvement over out=. It can raise an error if x is a view. It cannot know if a user writes x = or z =, and the latter may lead to bugs, which is unfortunate. Static analysis tools may perhaps help avoid such bugs.
It seems like the may_overwrite or the buffer arg could be calculated post facto by a sophisticated enough JIT, no?
@rgommers what do you mean by "in-place operations that are unambiguous"? Is this referring specifically to indexing-based assignment and out?
Making views immutable sounds like a pretty reasonable design decision. I don't know if current array libraries would be happy with that (there are certainly existing cases where views are written to), but at the least this sounds like exactly the sort of design decision we should allow. That said, if we allow "immutable" arrays in the case of views, why not allow raising an error for mutating arrays under other circumstances, too?
In that case, I don't think either (4) or (5) is much better than (2), making in-place operations optional, with an understanding that this will make it harder to write efficient generic logic in some cases (but this is likely already impossible in general). To elaborate:
out may raise TypeError (either explicitly or due to a missing keyword argument). All such arrays should declare themselves as immutable in some standard way, i.e., the equivalent of ndarray.flags.writeable from NumPy.

Examples of in-place arithmetic:
def a(x):  # bad, assumes x is modified in-place
    x += 1
    x /= 2

def b(x):  # bad, assumes x is copied by +=
    x += 1
    return x/2

def c(x):  # good, in-place arithmetic is only used as a hint within a function
    x = x + 1
    x /= 2
    return x

def d(x):  # good, premature optimization is the root of all evil :)
    return (x + 1) / 2
I think users will have to learn these rules themselves. Python function boundaries generally aren't available at the level of array libraries, so we can't do much to help them. Fortunately, these are already best practices for writing code that works generically over both arrays and scalars.
Potentially static type checkers could also catch this sort of thing?
It seems like the may_overwrite or the buffer arg could be calculated post facto by a sophisticated enough JIT, no?
Yes, in fact this is exactly the case in JAX :)
@rgommers what do you mean by "in-place operations that are unambiguous"? Is this referring specifically to indexing-based assignment and out?
Yes indeed (unambiguous unless the target affects a view).
That said, if we allow "immutable" arrays in the case of views, why not allow raising an error for mutating arrays under other circumstances, too?
I think the differences are (a) the impact on existing code (acceptable for immutable views, a train wreck for all mutations), and (b) how easy the resulting code changes are to make (for views, just add a .copy()).
Here's a search result for some forms of indexing assignment use (:] = and 0] =) for scikit-image:
That tells me that if mutation is optional, scikit-image is simply going to ignore the array types that don't implement it. As will SciPy et al. Maybe that just is what it is, but what I want to know is whether it can be avoided by JAX adding a simple wrapper that maps x[idx] = y to x = x.at[idx].set(y). If that's not that hard to do, and if a lot of those cases are not inefficient (which I think is the case), then would it really be so bad adding that wrapper?
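A rough sketch of what such a wrapper could look like (MutableWrapper is a hypothetical name, not an existing JAX class; it assumes a JAX-style .at[...].set() API underneath):

```python
import jax.numpy as jnp

class MutableWrapper:
    """Hypothetical wrapper giving x[idx] = y syntax on top of .at[idx].set(y)."""
    def __init__(self, array):
        self._array = array

    def __setitem__(self, idx, value):
        # no mutation: build a new array and rebind the internal reference
        self._array = self._array.at[idx].set(value)

    def unwrap(self):
        return self._array

w = MutableWrapper(jnp.zeros((2, 3)))
w[:, 0] = 1.0          # translated to an .at[...].set() call
print(w.unwrap())      # [[1. 0. 0.] [1. 0. 0.]]
```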
The reason why x[idx] = y in JAX is challenging ultimately comes down to confusion between copies vs views. Specifically, consider f(x):
def f(x):
    x[:, 0] = 1  # something in-place
    ...
f(x) operates in-place, but jit(f)(x) would not (jit makes copies of array arguments).
So if we can figure out a usage pattern that avoids that confusion by recommending (or requiring?) a call to ensure a copy, I think we could make it work in JAX, e.g.,
y = x.copy()
y[:, 0] = z
If we have a reliable way to detect "views" (with reference counting?), we might be able to call this method ensure_copy() or mutable() rather than copy().
JAX would probably choose to make all operations "views" by default, just to preserve maximum flexibility and require best practices. In fact, I would guess that the default array type wouldn't even implement __setitem__ -- you'd always have to call .copy() to get a mutable array.
If we have a reliable way to detect "views" (with reference counting?), we might be able to call this method ensure_copy() or mutable() rather than copy().
That's an interesting idea. I think it's feasible to build this, and the extra API surface of mutable() and the code changes needed in libraries using __setitem__ are a small price to pay for making this work well across all array libraries.
A quick test to set a baseline, here are the results of running the test suites of some SciPy modules after making numpy.asarray and numpy.random.random return readonly arrays:
- stats: 608 failed, 879 passed
- ndimage: 122 failed, 380 passed
- optimize: 345 failed, 1033 passed
- interpolate: 162 failed, 209 passed

Testing the impact of the mutable() variant is a lot more difficult, but I'll see if I can find a volunteer for that.
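For reference, the baseline change can be approximated with a small monkeypatch like this (a sketch, assuming a Python-level patch rather than a modified NumPy build; it only affects callers that go through np.asarray):

```python
import numpy as np

_orig_asarray = np.asarray

def _readonly_asarray(obj, *args, **kwargs):
    a = _orig_asarray(obj, *args, **kwargs)
    a.flags.writeable = False   # hand out read-only arrays to downstream code
    return a

np.asarray = _readonly_asarray  # same idea applies to np.random.random
```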
Coming back to this after finding it's probably too hard to implement .mutable() in the form discussed above.
In-place arithmetic should be supported by all arrays, but users should not assume that it manipulates generic array values either in place or by making a copy. Instead, in-place arithmetic is a hint that it is safe to overwrite array values as a potential optimization. (As I noted earlier, this is basically decided already at the Python language level.)
Played with this some more:
import numpy as np
import jax.numpy as jnp
def f(mod):
    x = mod.arange(5)
    y = x[:2]
    y += 1
    return x
print(f(np))
print(f(jnp))
yields
[1 2 2 3 4]
[0 1 2 3 4]
This is actually 100% identical to the problem with slice assignment. So saying for += and other in-place operators "users will have to learn these rules themselves", but for slice assignment that we need a new usage pattern like .mutable(), is pretty inconsistent.
As I noted earlier, this is basically decided already at the Python language level.
It seems to me that if JAX were principled on no mutation, it should have implemented __iadd__ and raised a TypeError from it. That Python has some rule to forward a dunder method is irrelevant from a user perspective.
@mattip is going to help with an experiment, creating a NumPy branch which sets flags.writeable to False on creation of a view (without setting it back to True if the view goes out of scope). That seems feasible, because it should happen in the same place where the .base attribute is set.
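A rough Python-level sketch of what that experiment does (the actual change is in NumPy's C code, where the .base attribute is set; view_of is a hypothetical helper used only for illustration):

```python
import numpy as np

def view_of(arr, index):
    # hypothetical helper illustrating the experiment: any newly created
    # view gets flags.writeable set to False
    v = arr[index]
    if v.base is not None:        # v shares memory with arr, i.e. it is a view
        v.flags.writeable = False
    return v

x = np.ones((2, 3))
y = view_of(x, np.s_[0, :])
# y[0] = 2 would now raise: ValueError: assignment destination is read-only
```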
This is actually 100% identical to the problem with slice assignment.
I'm not entirely sure I follow what you mean here.
Examples like your f(mod) are certainly not a great idea in any array library. Raising errors when attempting to write to views certainly seems like a good start.
So saying for += and other in-place operators "users will have to learn these rules themselves", but for slice assignment that we need a new usage pattern like .mutable(), is pretty inconsistent.
If it helps, we could still call it .copy()? If you don't know if any expression (or input into a function) is going to give a view or a copy, then you need to make a defensive copy anyways in libraries like NumPy. This is not going to be maximally efficient but may be good enough for much generic code, and it's already done all over the place in idiomatic NumPy code. In fact it's already pretty hard to know when you'll get views vs copies unless you have expert-level knowledge of NumPy.
My main concern is that we should try to avoid baking specific view vs copy semantics directly into the standard. For example, if a new library wants to implement arithmetic as lazy expressions (like in dask or xtensor) that should be OK. Lazy expressions are effectively a form of views, so you can't count on being able to write something like this, even though it's safe in NumPy:
y = x + 1
y[:2] = 0 # this should raise an error, if y is a "view"
My main concern is that we should try to avoid baking specific view vs copy semantics directly into the standard.
Yes, I agree completely. The question is just what form that takes. I'm trying to get to a place where we don't have to guess at all, but slice assignment is still allowed.
I'm not entirely sure I follow what you mean here.
- x += 1 is an in-place operation; it can be translated to a non-mutating form (x = ...) reliably if no views are involved.
- x[idx] = 1 is an in-place operation; it can be translated to a non-mutating form (x = ...) reliably if no views are involved.

Here is a related PyTorch issue with ideas and rationale for adding immutable tensors and returning them in some cases where view/copy behaviour will be hard to predict: https://github.com/pytorch/pytorch/issues/44027
Here is a hacky patch for NumPy that makes both the base array and the view read-only if a view is created:
It doesn't completely do the job it's supposed to do - it warns when a mutating operation affects a view, however it also triggers on regular slice assignment:
In [1]: x = np.ones((2, 3), dtype=np.float64)
In [2]: x[0, 0] = 1
In [3]: x[:2] = 1
... DeprecationWarning: Numpy has detected that you (may be) writing to an array with overlapping memory ...
In [4]: x.flags.writeable
Out[4]: True
In [5]: y = x[0, :]
In [6]: x.flags.writeable
... FutureWarning: future versions will not create a writeable array ...
Out[6]: True
In [7]: x.flags
Out[7]:
C_CONTIGUOUS : True
F_CONTIGUOUS : False
OWNDATA : True
WRITEABLE : True (with WARN_ON_WRITE=True)
ALIGNED : True
WRITEBACKIFCOPY : False
UPDATEIFCOPY : False
In [8]: y.flags
Out[8]:
C_CONTIGUOUS : True
F_CONTIGUOUS : True
OWNDATA : False
WRITEABLE : True (with WARN_ON_WRITE=True)
ALIGNED : True
WRITEBACKIFCOPY : False
UPDATEIFCOPY : False
That could be prevented by starting from within __setitem__ and passing down to PyArray_SetBaseObject the information that the .writeable flag should not be touched.
At that point there's still annoying behaviour left though; for example, just evaluating x[:-1] in IPython will create a view that the interpreter holds on to - and hence flip the flag.
Trying out the effect on SciPy:
python -c "from scipy import ndimage; ndimage.test(extra_argv=['-W', 'error::DeprecationWarning'])"
gives 168 failed, 332 passed. Many failures are due to NumPy functions triggering the warning though - for example np.eye chokes on m[:M-k].flat[i::M+1] = 1 in its implementation (fixing that gets rid of 18 test failures already). Fixing that and a similar issue in np.indices gets me down to 126 failed, 374 passed. Still, that's a lot of failures - and some other SciPy modules won't even run with "error on deprecation", because something fails before pytest is done with test discovery.
tl;dr this is going to be a real pain to do.
I wrote some tests that make it easier to figure out which libraries have matching behaviour for in-place operations:
Results in:
f1 f2 f3 f4 f5 f6 f7 f8 f9 f10 f11 f12
numpy 1 2 2 2 2 2 2 3 0 1 1 2
pytorch 1 2 2 2 2 2 2 3 0 1 1 2
MXNet 1 2 2 2 1 2 2 3 0 1 1 2
cupy 1 2 2 2 2 2 2 3 0 1 1 2
dask 1 3 1 1 1 1 -9 3 0 1 1 1
tensorflow 1 3 1 1 1 1 -9 3 0 -9 -9 1
jax 1 3 1 1 1 1 -9 3 0 1 1 1
f1: Add, then in-place subtract
f2: In-place add, alias, then in-place subtract
f3: Slice, then in-place add on slice
f4: Reshape, then in-place add
f5: Slice with step size 2, then in-place add
f6: Alias, then in-place add
f7: Check which array types support slice assignment syntax
f8: Do the actual slice assignment in the way each library wants it done
f9: `diag` is known to have inconsistent behaviour, test it
f10: Indexing with list of integers, then in-place add
f11: Boolean indexing, then in-place add
f12: Indexing with ellipsis, then in-place add
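For reference, a test like f3 might look roughly like this (a sketch, not the actual test code; mod is the array module under test, as in the earlier f(mod) example):

```python
import numpy as np

def f3(mod):
    # Slice, then in-place add on the slice; the result reveals whether
    # slicing produced a view (the base array changes) or a copy (it doesn't).
    x = mod.arange(6)
    y = x[:3]
    y += 1
    return x

print(f3(np))   # [1 2 3 3 4 5] for view behaviour; [0 1 2 3 4 5] for copy behaviour
```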
Check if behaviour equals that of NumPy (1 or -1 means copy/view behaviour mismatch, -9 means unsupported behaviour):
f1 f2 f3 f4 f5 f6 f7 f8 f9 f10 f11 f12
numpy 0 0 0 0 0 0 0 0 0 0 0 0
pytorch 0 0 0 0 0 0 0 0 0 0 0 0
MXNet 0 0 0 0 -1 0 0 0 0 0 0 0
cupy 0 0 0 0 0 0 0 0 0 0 0 0
dask 0 1 -1 -1 -1 -1 -9 0 0 0 0 -1
tensorflow 0 1 -1 -1 -1 -1 -9 0 0 -9 -9 -1
jax 0 1 -1 -1 -1 -1 -9 0 0 0 0 -1
Some observations:
- The only copy/view behaviour mismatch with NumPy in the table above is MXNet's, for slicing with step size 2, [::2, :]. (there will be more mismatches for individual functions and more complex cases)
- dask, tensorflow and jax don't support the slice assignment syntax, and behave differently from NumPy when an in-place operation goes through an alias (y = x; y -= 1).

I'm leaning towards the following choices:
- Remove out= completely, and don't replace it with anything.

Here is a hacky patch for NumPy that makes both the base array and the view read-only if a view is created:
Would it be more practical to make only the view read-only? If you make the base array read-only as well, then in addition to errors like the IPython example you show, functions that create views would now effectively be mutating their input arguments.
Yes, that's a good point. I undid that change, and added some more fixes (current patch here). That at least makes the NumPy test suite start and run to completion without crashing. Resulting in:
1268 failed, 10174 passed, 67 skipped, 10 errors
The troubles are:
- flags.writeable = True

I think it would take too much time to complete this change and make the NumPy test suite (mostly) pass - and that'd be needed to be able to assess the impact on SciPy and other packages.
That, plus the fact that it's a massive backwards compat break for all of NumPy, PyTorch, CuPy and MXNet, makes me think we should simply go with "Add a recommendation that users avoid any mutating operation when a view may be involved".
Everyone in the meeting today was good with points 1-4 in https://github.com/data-apis/array-api/issues/24#issuecomment-689907016. I'll open a PR that discusses mutability, and @kgryte will remove out=.
Context:
That issue and PR were about unrelated topics, so I'll try to summarize the copy-view and mutation topic here and we can continue the discussion.
Note that the two topics are fairly coupled, because copy/view differences only matter (for semantics, not for performance) when mixed with mutation.
Mutating arrays
There are a number of things that may rely on mutation:
- in-place operators, e.g. +=, *=
- the out= keyword argument
- item/slice assignment via __setitem__
Summary of the issue with mutation by @shoyer was: Mutation can be challenging to support in some execution models (at least without another layer of indirection), which is why several projects currently don't support it (TensorFlow and JAX) or only support it half-heartedly (e.g., Dask). The commonality between these libraries is that they build up abstract computations, which is then transformed (e.g., for autodiff) and/or executed in parallel. Even NumPy has "read only" arrays. I'm particularly concerned about new projects that implement this API, which might find the need to support mutation burdensome.
@alextp said: TensorFlow was planning to add mutability and didn't see a real issue with supporting out=.
@shoyer said: It's definitely always possible to support mutation at the Python level via some sort of wrapper layer. dask.array is perhaps a good example of this. It supports mutating operations and out in some cases, but its support for mutation is still rather limited. For example, it doesn't support assignment like x[:2, :] = some_other_array.

Working around limitations of no support for mutation can usually be done by one of:
1. using where for selection, e.g., where(arange(4) == 2, 1, 0)
2. fancy indexing, e.g. y = array([0, 1]); x = y[[0, 0, 1, 0]] in this case

Some version of (2) always works, though it can be tricky to work out (especially with current APIs). The duality between indexing and assignment is the difference between specifying where elements come from or where they end up.
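Both workarounds, spelled out in NumPy for concreteness (a sketch; a non-mutating library would offer the same functions without the assignment forms):

```python
import numpy as np

# (1) selection with where, instead of assigning into a slice:
x = np.where(np.arange(4) == 2, 1, 0)    # [0 0 1 0]

# (2) fancy indexing: say where elements come from, not where they end up:
y = np.array([0, 1])
x = y[[0, 0, 1, 0]]                      # [0 0 1 0]
```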
The JAX syntax for slice assignment is:
x.at[idx].set(y) vs x[idx] = y
One advantage of the non-mutating version is that JAX can have reliable assigning arithmetic on array slices with x.at[idx].add(y) (x[idx] += y doesn't work if x[idx] returns a copy).
A disadvantage is that doing this sort of thing inside a loop is almost always a bad idea unless you have a JIT compiler, because every indexing assignment operation makes a full copy. So the naive translation of an efficient Python loop that fills out an array row by row would now make a copy in each step. Instead, you'd have to rewrite that loop to use something like concatenate (which in my experience is already about as efficient as using indexing assignment).
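A short runnable version of the non-mutating spelling (using JAX's public .at API; a sketch, not taken from the original comment):

```python
import jax.numpy as jnp

x = jnp.zeros(4)
x = x.at[1:3].set(1.0)   # non-mutating slice assignment: [0. 1. 1. 0.]
x = x.at[1:3].add(1.0)   # reliable "x[1:3] += 1", even though x[1:3] returns a copy
```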
Copy-view behaviour
Libraries like NumPy and PyTorch return views where possible from function calls. It's sometimes hard to predict when a view will be returned vs. when a copy - it not only depends on the function in question, but also on whether the input array is contiguous, and sometimes even on input dtype.
This is one place where it's hard to avoid implementation choices leaking into the API: some libraries return a view from a function like transpose(), while others return a copy from the same .transpose() call.
The above copy vs. view difference starts leaking into the API - i.e., the same code starts giving different results for different implementations - when it is combined with an operation that performs in-place mutation of an array (either the base array or the view on it). In the absence of that combination, views are simply a performance optimization that's invisible to the user.
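A small NumPy example of how that combination makes the copy/view difference observable (in a library where transpose() returns a copy, x would stay unchanged):

```python
import numpy as np

x = np.arange(6).reshape(2, 3)
y = x.transpose()     # a view in NumPy
y[0, 0] = 99          # mutate the view...
print(x[0, 0])        # ...and the base array changes too: prints 99
```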
The question is whether copy-view differences should be allowed, and if so how to deal with the semantics that vary between libraries.
To answer whether it should be allowed, let's first ask how often the combination of views and mutation is used. A few observations:
- A search for *=, += and ] = in SciPy and scikit-learn .py files shows that in-place mutation inside functions is heavily used.
- There's a distinction between mutating an existing array in full (e.g. with += 1) and mutating part of an array (e.g. with x[:, :2] = y). The former is a lot easier to support for array libraries employing static graphs or a JIT than the latter. See the discussion at https://github.com/data-apis/array-api/issues/8#issuecomment-673202302 for details.

Options for how to standardize
In https://github.com/data-apis/array-api/issues/8 @shoyer listed the options for how to deal with mutability, one of which is to make support for it optional and have arrays signal that via the equivalent of ndarray.flags.writeable. (From later discussion, see https://github.com/data-apis/array-api/issues/8#issuecomment-674514340 for the implication of that for users of the API.)

To that I'd like to add a more granular option:
Require support for in-place operations that are unambiguous, and require raising an exception in case a view is mutated.
Rationale:
(a) This would require libraries that don't support mutation to write a wrapper layer, but the behaviour would be unambiguous and in most cases the wrapper would not be inefficient. (b) In case inefficient mutation is detected (e.g. mutating a large array row-by-row in a loop), a warning may be emitted.
A variant of this option would be:
Require support for in-place operations that are unambiguous and mutate the whole array at once (i.e. += and out= must be supported, element/slice assignment must raise an exception), and require raising an exception in case a view is mutated.

The trade-off here is ease of implementation for libraries like Dask and JAX vs. putting a rewrite burden on SciPy et al. and a usability burden on end users (the alternative to element/slice assignment is unintuitive).
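To make the variant concrete (a sketch of the intended split, written with NumPy syntax; not normative wording):

```python
import numpy as np

x = np.ones((4, 4))
x += 1              # must be supported: whole-array in-place arithmetic
np.sin(x, out=x)    # must be supported: out= overwrites the whole array at once
x[0, :] = 1         # must raise an exception under this variant (plain NumPy would allow it)
```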