data-apis / array-api

RFC document, tooling and other content related to the array API standard
https://data-apis.github.io/array-api/latest/
MIT License

Copy-view behaviour and mutating arrays #24

Closed rgommers closed 3 years ago

rgommers commented 3 years ago

Context:

That issue and PR were about unrelated topics, so I'll try to summarize the copy-view and mutation topic here and we can continue the discussion.

Note that the two topics are fairly coupled, because copy/view differences only matter (for semantics, not for performance) when mixed with mutation.

Mutating arrays

There's a number of things that may rely on mutation:

Summary of the issue with mutation by @shoyer was: Mutation can be challenging to support in some execution models (at least without another layer of indirection), which is why several projects currently don't support it (TensorFlow and JAX) or only support it half-heartedly (e.g., Dask). The commonality between these libraries is that they build up abstract computations, which are then transformed (e.g., for autodiff) and/or executed in parallel. Even NumPy has "read only" arrays. I'm particularly concerned about new projects that implement this API, which might find the need to support mutation burdensome.

@alextp said: TensorFlow was planning to add mutability and didn't see a real issue with supporting out=.

@shoyer said: It's definitely always possible to support mutation at the Python level via some sort of wrapper layer.

dask.array is perhaps a good example of this. It supports mutating operations and out in some cases, but its support for mutation is still rather limited. For example, it doesn't support assignment like x[:2, :] = some_other_array.

Working around limitations of no support for mutation can usually be done by one of:

  1. Use where for selection, e.g., where(arange(4) == 2, 1, 0)
  2. Calculate the "inverse" of the assignment operator in terms of indexing, e.g., y = array([0, 1]); x = y[[0, 0, 1, 0]] in this case

Some version of (2) always works, though it can be tricky to work out (especially with current APIs). The duality between indexing and assignment is the difference between specifying where elements come from or where they end up.
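As a minimal sketch of the two workarounds above, assuming NumPy as the array library (the variable names are only illustrative), all three forms below produce the same array:

```python
import numpy as np

# Mutating form: start from zeros and assign one element in place.
x_mut = np.zeros(4, dtype=int)
x_mut[2] = 1

# (1) Selection with where: the result is built without writing into
# any existing array.
x_where = np.where(np.arange(4) == 2, 1, 0)

# (2) "Inverse" of assignment: instead of saying where a value ends up,
# say where each output element comes from.
y = np.array([0, 1])
x_gather = y[[0, 0, 1, 0]]

print(x_mut.tolist(), x_where.tolist(), x_gather.tolist())
```

The index list in (2) is what has to be worked out by hand, which is the tricky part mentioned above.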

The JAX syntax for slice assignment is: x.at[idx].set(y) vs x[idx] = y

One advantage of the non-mutating version is that JAX can have reliable assigning arithmetic on array slices with x.at[idx].add(y) (x[idx] += y doesn't work if x[idx] returns a copy).

A disadvantage is that doing this sort of thing inside a loop is almost always a bad idea unless you have a JIT compiler, because every indexing assignment operation makes a full copy. So the naive translation of an efficient Python loop that fills out an array row by row would now make a copy in each step. Instead, you'd have to rewrite that loop to use something like concatenate instead (which in my experience is already about as efficient as using indexing assignment).
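To illustrate the rewrite described above, here is a small sketch, assuming NumPy and two hypothetical helper functions: one fills a preallocated array row by row, the other collects rows and joins them once at the end with np.stack (a convenience built on concatenation).

```python
import numpy as np


def rows_inplace(n, m):
    # Preallocate, then fill each row with indexing assignment.
    out = np.empty((n, m))
    for i in range(n):
        out[i, :] = np.arange(m, dtype=float) * i
    return out


def rows_stacked(n, m):
    # No mutation: build the rows first, join them in one step.
    rows = [np.arange(m, dtype=float) * i for i in range(n)]
    return np.stack(rows)


a = rows_inplace(3, 4)
b = rows_stacked(3, 4)
print(np.array_equal(a, b))
```

The second form never writes into an existing array, so it carries over to libraries without slice assignment.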

Copy-view behaviour

Libraries like NumPy and PyTorch return views where possible from function calls. It's sometimes hard to predict when a view will be returned vs. when a copy - it not only depends on the function in question, but also on whether the input array is contiguous, and sometimes even on input dtype.
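A small sketch of the contiguity dependence, assuming NumPy: the same reshape call returns a view for a C-contiguous input and a copy for its transpose.

```python
import numpy as np

a = np.arange(6).reshape(2, 3)   # C-contiguous

flat_view = a.reshape(6)         # contiguous input: a view is possible
flat_copy = a.T.reshape(6)       # transposed input: a copy is needed

print(np.shares_memory(a, flat_view))  # True
print(np.shares_memory(a, flat_copy))  # False
```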

This is one place where it's hard to avoid implementation choices leaking into the API:

The above copy vs. view difference starts leaking into the API - i.e., the same code starts giving different results for different implementations - when it is combined with an operation that performs in-place mutation of an array (either the base array or the view on it). In the absence of that combination, views are simply a performance optimization that's invisible to the user.
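As a concrete sketch of that combination, assuming NumPy semantics (basic slicing returns a view, integer-array indexing returns a copy), the same in-place addition is visible through the base array in one case and not in the other:

```python
import numpy as np

x = np.arange(5)
v = x[:2]        # basic slicing: a view in NumPy
c = x[[0, 1]]    # integer-array indexing: a copy in NumPy

v += 10          # writes through to x
c += 100         # does not touch x

print(x.tolist())  # [10, 11, 2, 3, 4]
```

A library that returned a copy for x[:2] would leave x unchanged here, which is exactly the kind of divergence in question.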

The question is whether copy-view differences should be allowed, and if so how to deal with the semantics that vary between libraries.

To answer whether it should be allowed, let's first ask how often the combination of views and mutation is used. A few observations:

  1. It is normally considered a bug if a library function (e.g. a SciPy or scikit-learn one) mutates any of its input arguments - unless the function is explicitly documented as doing so, which is rare. So the main concern is use inside functions, with arrays that are either created inside the function or use a copy of the input array.
  2. A search for patterns like *=, += and ] = in SciPy and scikit-learn .py files shows that in-place mutation inside functions is heavily used.
  3. There's a significant difference between mutating a complete array (e.g. with += 1) and mutating part of an array (e.g. with x[:, :2] = y). The former is a lot easier to support for array libraries employing static graphs or a JIT than the latter. See the discussion at https://github.com/data-apis/array-api/issues/8#issuecomment-673202302 for details.
  4. It's harder to figure out how often the combination of mutating part of an array and that mutation affecting a view occurs. This could be tested though, with a patched NumPy to raise an exception on mutations affecting a view and then running test suites of downstream libraries.

Options for how to standardize

In https://github.com/data-apis/array-api/issues/8 @shoyer listed the following options for how to deal with mutability:

  1. Require support for in-place operations. Libraries that don't support mutation fully will need to write a wrapper layer, even if it would be inefficient.
  2. Make support for in-place operations optional. Arrays can indicate whether they support mutation via some standard API, e.g., like NumPy's ndarray.flags.writeable. (From later discussion, see https://github.com/data-apis/array-api/issues/8#issuecomment-674514340 for the implication of that for users of the API).
  3. Don't include support for in-place operations in the spec. This is a conservative choice, one which might have negative performance consequences (but it's a little hard to say without looking carefully). At the very least, it might require a library like SciPy to retain a special path for numpy.ndarray objects.
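For option 2, a rough sketch of what consumer code might look like, assuming NumPy's flags.writeable as the "standard API" and a hypothetical helper add_one (a real standard could spell the check differently):

```python
import numpy as np


def add_one(x):
    """Add 1 in place if the array reports itself writeable, else out of place."""
    flags = getattr(x, "flags", None)
    if flags is not None and flags.writeable:
        x += 1          # safe to reuse the buffer
        return x
    return x + 1        # fall back to a new array


a = np.zeros(3)
b = np.zeros(3)
b.setflags(write=False)   # read-only array

ra = add_one(a)
rb = add_one(b)
print(ra is a, rb is b)
```

Every function that wants the in-place fast path needs a branch like this, which is the cost for users of the API that the linked comment discusses.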

To that I'd like to add a more granular option:

  4. Require support for in-place operations that are unambiguous, and require raising an exception in case a view is mutated.

    Rationale:

    (a) This would require libraries that don't support mutation to write a wrapper layer, but the behaviour would be unambiguous and in most cases the wrapper would not be inefficient. (b) In case inefficient mutation is detected (e.g. mutating a large array row-by-row in a loop), a warning may be emitted.

A variant of this option would be:

  5. Require support for in-place operations that are unambiguous and mutate the whole array at once (i.e. += and out= must be supported, element/slice assignment must raise an exception), and require raising an exception in case a view is mutated.

    Trade-off here is ease of implementation for libraries like Dask and JAX vs. putting a rewrite burden on SciPy et al. and a usability burden on end users (the alternative to element/slice assignment is unintuitive).

shoyer commented 3 years ago

@honnibal (of spaCy fame) left a really good comment on this topic on the Twitter announcement thread (https://twitter.com/honnibal/status/1295369359653842945):

Btw an API suggestion: I think out= is really unideal as an API. Functions should optionally accept one or more buffers of memory they may use if they help -- or not

The problem with out= is it mixes the efficiency consideration of "you're allowed to do this inplace" with the semantics of the function (the output must be in this array). This gives the user no way to say "Just do whatever's fastest". It's also not nearly general enough

I think reframing things around optionally reused buffers could be a good way to bridge the gap between mutation based and non-mutation based APIs. For example, most of JAX's APIs are based on non-mutating pure functions but JAX has a notion of "donated" arguments (donate_argnums in jax.jit) whose buffers may be reused inplace.

Interestingly, operations like += in Python effectively automatically implement "buffer reuse" rather than "inplace" semantics. An operation like x += 1 may or may not reuse the x buffer, e.g., depending on whether x is a scalar or NumPy array. This is implemented at the language level: if __iadd__ doesn't exist, Python simply calls __add__ instead, which is how TensorFlow/JAX support inplace arithmetic even if they don't support mutation.
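The language-level fallback can be seen in a few lines. This is a sketch with a hypothetical Immutable class that defines __add__ but not __iadd__, contrasted with a NumPy array, which does define __iadd__:

```python
import numpy as np


class Immutable:
    """Hypothetical value type with __add__ but no __iadd__."""

    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        return Immutable(self.value + other)


x = Immutable(1)
x_alias = x
x += 1                    # no __iadd__: Python evaluates x = x + 1
print(x is x_alias)       # False, x was rebound to a new object

a = np.zeros(1)
a_alias = a
a += 1                    # ndarray.__iadd__ reuses the buffer
print(a is a_alias)       # True
```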

Would it be too much API innovation to add a new buffer argument rather than supporting out? This would need a new argument name, unfortunately, because even if APIs that take out always return an output value, too, people may be discarding the return value and relying on the inplace semantics instead.

There is also the challenge of how to spell optional buffer reuse for indexing assignment. Perhaps something like x = x.at[idx].set(y, reuse_buffer=True), x.at[idx].set_inplace(y) or x.at[idx].iset(y)? It's pretty awkward, though.

rgommers commented 3 years ago

Would it be too much API innovation to add a new buffer argument rather than supporting out?

Maybe not too much, it's only a single keyword after all. It seems to be an incomplete design though - in the absence of whole-program optimization, the semantics of buffer are unclear as far as I can tell. How does a "you may use this buffer or you may not" result in predictable code for use of numpy? To make that concrete, how would you change

x = np.ones(5)                                                        
y = np.empty_like(x)                                                  
np.sin(x, out=y) 

Maybe I'm missing something here, but buffer looks a bit like the optional "may overwrite" working buffers in LAPACK.

I think this out= decision is still the least important of all that needs to be decided here though. I won't lose sleep over removing out= if needed.

There is also the challenge of how to spell optional buffer reuse for indexing assignment. Perhaps something like x = x.at[idx].set(y, reuse_buffer=True), x.at[idx].set_inplace(y) or x.at[idx].iset(y)? It's pretty awkward, though.

Yep, that's the harder part. Disregarding out=, what do you think about option 4 or 5? That to me seems like a compromise where NumPy/PyTorch/CuPy users will need to do a manageable amount of code rewriting, while JAX/TensorFlow/MXNet need to add a wrapper layer (also manageable). And it avoids the problems with inconsistent copy/view returns.

Regarding your option 3 (no mutation at all), I'm afraid that'll be a reason for outright rejecting of the whole standard from the community - it requires a whole change of thinking and a ton of code changes.

Your option 2 (make support for in-place operations optional) seems like the next-best thing. The most likely result will be that functionality that relies on in-place mutation won't support JAX for the time being, until the at[idx].set becomes available more widely and we add it to the standard.

rgommers commented 3 years ago

Would it be too much API innovation

A procedural thought: for this kind of addition of something that doesn't exist yet, we should probably have a separate status in the standard, like "provisional". Like a recommendation that, if one wants to add a feature like that, it must be spelled like X and behave like Y. That guarantees that projects experimenting with this because they like to improve on out= don't diverge. And it avoids mandating something that may turn out to be sub-optimal. I'm a little uncomfortable with adding new things unless it's quite clear that there'll be no hiccups during implementation.

shoyer commented 3 years ago

To make that concrete, how would you change

x = np.ones(5)                                                        
y = np.empty_like(x)                                                  
np.sin(x, out=y) 

Maybe I'm missing something here, but buffer looks a bit like the optional "may overwrite" working buffers in LAPACK.

This might become something like:

x = np.ones(5)
y = np.empty_like(x)
z = np.sin(x, buffer=y)

where the values that end up filling y are not well defined.

That said, I do think something like "may overwrite" is a better way to spell this than allowing filling into arbitrary buffers, closer in spirit both to working buffers in LAPACK and Python's own inplace arithmetic. As a user, you would write:

x = np.ones(5)
z = np.sin(x, may_overwrite=True)

The values of x after computing z would be undefined. In practice, the safe way to write this would thus be to reuse the variable x, i.e., x = np.sin(x, may_overwrite=True).

Inplace operations now become an optimization detail, rather than a requirement as would be the case for np.sin(x, out=x).
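To make the idea concrete, here is a sketch of how a library could honor such a hint. The may_overwrite keyword does not exist in NumPy; the wrapper function sin below is purely hypothetical and only uses the existing out= machinery underneath.

```python
import numpy as np


def sin(x, may_overwrite=False):
    """Hypothetical wrapper: reuse x's buffer only when allowed and possible."""
    can_reuse = (
        may_overwrite
        and isinstance(x, np.ndarray)
        and x.flags.writeable
        and np.issubdtype(x.dtype, np.floating)
    )
    if can_reuse:
        return np.sin(x, out=x)   # implementation detail, not a promise
    return np.sin(x)


x = np.ones(5)
x = sin(x, may_overwrite=True)    # the safe spelling: rebind the same name
print(np.allclose(x, np.sin(1.0)))
```

The caller gets the same result whether or not the buffer was reused, as long as the old name is not read again.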

rgommers commented 3 years ago

where the values that end up filling y are not well defined. .... The values of x after computing z would be undefined.

yes, that's the footgun I was worried about

In practice, the safe way to write this would thus be to reuse the variable x, i.e., x = np.sin(x, may_overwrite=True).

That's better, I can see the appeal of that. A boolean keyword may not generalize well for multiple inputs, which I think was a point @honnibal was making with "one or more buffers". It's also why out= takes a tuple. That said, I've never seen that actually used, so it may not be relevant.

So x = np.sin(x, may_overwrite=True) is likely an improvement over out=. It can raise an error if x is a view. It cannot know if a user does x = or z =, and the latter may lead to bugs, which is unfortunate. Static analysis tools may help avoid such bugs perhaps.

saulshanabrook commented 3 years ago

It seems like the may_overwrite or the buffer arg could be calculated post facto by a sophisticated enough JIT, no?

shoyer commented 3 years ago

@rgommers what do you mean by "in-place operations that are unambiguous"? Is this referring specifically to indexing based assignment and out?

Making views immutable sounds like a pretty reasonable design decision. I don't know if current array libraries would be happy with that (there are certainly existing cases where views are written to), but at the least this sounds like exactly the sort of design decision we should allow. That said, if we allow "immutable" arrays in the case of views, why not allow raising an error for mutating arrays under other circumstances, too?

In that case, I don't think either (4) or (5) is much better than (2), making in-place operations optional, with an understanding that this will make it harder to write efficient generic logic in some cases (but this is likely already impossible in general). To elaborate:

  1. Explicit indexing based assignment and out may raise TypeError (either explicitly or due to a missing keyword argument). All such arrays should declare themselves as immutable in some standard way, i.e., the equivalent of ndarray.flags.writeable from NumPy.
  2. In-place arithmetic should be supported by all arrays, but users should not assume that it either modifies generic array values in place or makes a copy. Instead, in-place arithmetic is a hint that it is safe to overwrite array values as a potential optimization. (As I noted earlier, this is basically decided already at the Python language level.)

Examples of in-place arithmetic:

def a(x):  # bad, assumes x is modified in-place
    x += 1
    x /= 2

def b(x):  # bad, assumes x is copied by +=
    x += 1
    return x/2

def c(x):  # good, in-place arithmetic is only used as a hint within a function
    x = x + 1
    x /= 2
    return x

def d(x):  # good, premature optimization is the root of all evil :) 
    return (x + 1) / 2

I think users will have to learn these rules themselves. Python function boundaries generally aren't available at the level of array libraries, so we can't do much to help them. Fortunately, these are already best practices for writing code that works generically over both arrays and scalars.
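A short sketch of why b is fragile and c is not, assuming NumPy arrays and plain Python floats as the two kinds of input:

```python
import numpy as np


def b(x):  # relies on += not touching the caller's data
    x += 1
    return x / 2


def c(x):  # rebinds first, so in-place arithmetic stays local
    x = x + 1
    x /= 2
    return x


arr_b = np.array([1.0, 3.0])
out_b = b(arr_b)
print(arr_b.tolist())   # [2.0, 4.0], the caller's array was modified

arr_c = np.array([1.0, 3.0])
out_c = c(arr_c)
print(arr_c.tolist())   # [1.0, 3.0], untouched

print(b(1.0), c(1.0))   # both fine for scalars: 1.0 1.0
```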

Potentially static type checkers could also catch this sort of thing?

It seems like the may_overwrite or the buffer arg could be calculated post facto by a sophisticated enough JIT, no?

Yes, in fact this is exactly the case in JAX.

rgommers commented 3 years ago

@rgommers what do you mean by "in-place operations that are unambiguous"? Is this referring specifically to indexing based assignment and out?

Yes indeed (unambiguous unless the target affects a view).

That said, if we allow "immutable" arrays in the case of views, why not allow raising an error for mutating arrays under other circumstances, too?

I think the differences are (a) the impact on existing code (acceptable for immutable views, a train wreck for all mutations), and (b) how easy the resulting code changes are to make (for views, just add a .copy()).

Here's a search result for some forms of image assignment use (:] = and 0] =) for scikit-image:

```
$ grin "\:\] =" */*.py
color/colorconv.py:
  513 : rgb_from_hdx[2, :] = np.cross(rgb_from_hdx[0, :], rgb_from_hdx[1, :])
  520 : rgb_from_fgx[2, :] = np.cross(rgb_from_fgx[0, :], rgb_from_fgx[1, :])
  527 : rgb_from_bex[2, :] = np.cross(rgb_from_bex[0, :], rgb_from_bex[1, :])
  540 : rgb_from_gdx[2, :] = np.cross(rgb_from_gdx[0, :], rgb_from_gdx[1, :])
  547 : rgb_from_hax[2, :] = np.cross(rgb_from_hax[0, :], rgb_from_hax[1, :])
  560 : rgb_from_bpx[2, :] = np.cross(rgb_from_bpx[0, :], rgb_from_bpx[1, :])
  567 : rgb_from_ahx[2, :] = np.cross(rgb_from_ahx[0, :], rgb_from_ahx[1, :])
  574 : rgb_from_hpx[2, :] = np.cross(rgb_from_hpx[0, :], rgb_from_hpx[1, :])
exposure/_adapthist.py:
  330 : view[:, :] = new
feature/_canny.py:
  287 : good_label[1:] = sums > 0
feature/corner.py:
  791 : >>> img[5:, 5:] = 1
  873 : b_dot[:] = bxx_y - bxy_x, byy_x - bxy_y
  874 : b_edge[:] = byy_y + bxy_x, bxx_x + bxy_y
  882 : corners_subpix[i, :] = np.nan, np.nan
  917 : corners_subpix[i, :] = y0 + est_dot[0], x0 + est_dot[1]
  919 : corners_subpix[i, :] = np.nan, np.nan
  921 : corners_subpix[i, :] = y0 + est_edge[0], x0 + est_edge[1]
feature/_daisy.py:
  121 : dy[:-1, :] = np.diff(image, n=1, axis=0)
  133 : hist[i, :, :] = exp(orientation_kappa * cos(grad_ori - o))
  135 : hist[i, :, :] = np.multiply(hist[i, :, :], grad_mag)
  142 : hist_smooth[i, j, :, :] = gaussian_filter(hist[j, :, :],
  150 : descs[:orientations, :, :] = hist_smooth[0, :, radius:-radius,
  159 : descs[idx:idx + orientations, :, :] = hist_smooth[i + 1, :,
feature/_hog.py:
  36 : g_row[0, :] = 0
  37 : g_row[-1, :] = 0
  38 : g_row[1:-1, :] = channel[2:, :] - channel[:-2, :]
  293 : normalized_blocks[r, c, :] = \
feature/peak.py:
  174 : mask[:remove // 2] = mask[-remove // 2:] = False
filters/edges.py:
  49 : result[0, :] = 0
  50 : result[-1, :] = 0
filters/_gabor.py:
  90 : g[:] = np.exp(-0.5 * (rotx ** 2 / sigma_x ** 2 + roty ** 2 / sigma_y ** 2))
filters/_rank_order.py:
  57 : original_values[1:] = flat_image[1:][is_different]
measure/_moments.py:
  423 : d[:] = mu[corners2] / mu0
measure/_polygon.py:
  138 : circular = np.all(coords[0, :] == coords[-1, :])
morphology/grey.py:
  124 : out[:] = crop(out_temp, pad_widths)
morphology/greyreconstruct.py:
  148 : dims[1:] = np.array(seed.shape) + 2 * padding
morphology/selem.py:
  322 : bfilter[:] = 1
novice/_novice.py:
  474 : self.xy_array[:] = value
restoration/inpaint.py:
  110 : >>> mask[2, 2:] = 1
  111 : >>> mask[1, 3:] = 1
  112 : >>> mask[0, 4:] = 1
segmentation/active_contour_model.py:
  110 : edge[i][0, :] = edge[i][1, :]
  111 : edge[i][-1, :] = edge[i][-2, :]
  148 : A[0, :] = 0
  149 : A[1, :] = 0
  154 : A[-1, :] = 0
  155 : A[-2, :] = 0
  156 : A[-2, -3:] = [1, -2, 1]
  160 : A[0, :] = 0
  162 : A[1, :] = 0
  167 : A[-1, :] = 0
  168 : A[-1, -3:] = [1, -2, 1]
  169 : A[-2, :] = 0
  170 : A[-2, -4:] = [-1, 3, -3, 1]
  211 : xsave[j, :] = x
  212 : ysave[j, :] = y
segmentation/_join.py:
  131 : inverse_map[(offset - 1):] = labels
segmentation/morphsnakes.py:
  37 : _P3[1][:, 1, :] = 1
  38 : _P3[2][1, :, :] = 1
  43 : _P3[7][[0, 1, 2], [0, 1, 2], :] = 1
  44 : _P3[8][[0, 1, 2], [2, 1, 0], :] = 1
transform/_geometric.py:
  893 : out[simplex == -1, :] = -1
  901 : out[index_mask, :] = affine(coords[index_mask, :])
  928 : out[simplex == -1, :] = -1
  936 : out[index_mask, :] = affine(coords[index_mask, :])
transform/hough_transform.py:
  200 : >>> img[30, :] = 1
transform/radon_transform.py:
  215 : fourier_filter[1:] = fourier_filter[1:] * np.sin(omega[1:]) / omega[1:]
  223 : fourier_filter[:] = 1
transform/_warps.py:
  529 : a[:] = b[:, :, np.newaxis]
  531 : a[:] = b
util/_montage.py:
  132 : arr_out[slices_row[idx_sr], slices_col[idx_sc], :] = image
util/_regular_grid.py:
  63 : stepsizes[dim + 1:] = ((space_size / n_points) **

$ grin "0\] =" */*.py
color/colorconv.py:
  273 : idx = (arr[:, :, 0] == out_v)
  274 : out[idx, 0] = (arr[idx, 1] - arr[idx, 2]) / delta[idx]
  278 : out[idx, 0] = 2. + (arr[idx, 2] - arr[idx, 0]) / delta[idx]
  282 : out[idx, 0] = 4. + (arr[idx, 0] - arr[idx, 1]) / delta[idx]
  289 : out[:, :, 0] = out_h
  643 : arr[arr < 0] = 0
draw/_random_shapes.py:
  94 : if shape[0] == 1 or shape[1] == 1:
  144 : if shape[0] == 1 or shape[1] == 1:
exposure/exposure.py:
  123 : >>> np.alltrue(cdf[0] == np.cumsum(hi[0])/float(image.size))
feature/corner.py:
  865 : N_dot[0, 0] = Axx
  866 : N_dot[0, 1] = N_dot[1, 0] = - Axy
  869 : N_edge[0, 0] = Ayy
  870 : N_edge[0, 1] = N_edge[1, 0] = Axy
feature/_hog.py:
  40 : g_col[:, 0] = 0
feature/orb.py:
  88 : >>> img1[40:60, 40:60] = square
feature/peak.py:
  112 : >>> img2[10, 10, 10] = 1
feature/texture.py:
  147 : glcm_sums[glcm_sums == 0] = 1
  253 : results[mask_0] = 1
feature/util.py:
  84 : new_shape1[0] = image2.shape[0]
  86 : new_shape2[0] = image1.shape[0]
  106 : offset[0] = 0
filters/edges.py:
  51 : result[:, 0] = 0
filters/_frangi.py:
  58 : lambda1[lambda1 == 0] = 1e-10
  123 : filtered[lambdas < 0] = 0
  125 : filtered[lambdas > 0] = 0
  176 : filtered[lambdas < 0] = 0
  181 : out[out <= 0] = 1
filters/_rank_order.py:
  56 : original_values[0] = flat_image[0]
measure/_find_contours.py:
  101 : >>> a[0, 0] = 1
measure/fit.py:
  86 : if data.shape[0] == 2: # well determined
  729 : >>> data[0] = (100, 100)
  765 : >>> dst[0] = (10000, 10000)
measure/_polygon.py:
  37 : chain[0] = True
morphology/greyreconstruct.py:
  75 : >>> y_seed[0] = 0.5
  149 : dims[0] = 2
morphology/selem.py:
  282 : selem[n, 0] = 1
  284 : selem[m + n - 1, 0] = 1
  333 : selem_rotated[c, 0] = selem_rotated[c, -1] = 1
  337 : selem[selem > 0] = 1
morphology/_skeletonize_3d.py:
  65 : img_o[img_o != 0] = 1
novice/__init__.py:
  79 : >>> picture[0:20, 0:20] = (0, 0, 0)
  93 : >>> picture[0:20, 0:20] = (0, 0, 0)
novice/_novice.py:
  227 : >>> data[:, :, 0] = 255 # Set red component to maximum
  242 : >>> pic[0, 0] = (0, 0, 0)
restoration/_denoise.py:
  207 : slices_p[0] = ax
  221 : slices_g[0] = ax
restoration/non_local_means.py:
  125 : >>> a[10:-10, 10:-10] = 1.
segmentation/active_contour_model.py:
  112 : edge[i][:, 0] = edge[i][:, 1]
  181 : fx[0] = 0
  182 : fy[0] = 0
  199 : dx[0] = 0
  200 : dy[0] = 0
segmentation/_join.py:
  127 : forward_map[labels0] = np.arange(offset, offset + len(labels0))
segmentation/morphsnakes.py:
  306 : u[aux < 0] = 1
  307 : u[aux > 0] = 0
  433 : u[aux > 0] = 1
  434 : u[aux < 0] = 0
segmentation/random_walker_segmentation.py:
  140 : labels[labels == 0] = X
transform/finite_radon_transform.py:
  61 : f[0] = ai.sum(axis=0)
  126 : f[0] = ai.sum(axis=0)
transform/_geometric.py:
  484 : S[0] = (S[0] + S[1]) / 2.0
  662 : A[:rows, 0] = xs
transform/hough_transform.py:
  23 : (`angles[-1] - angles[0] == PI`).
  202 : >>> img[35:45, 35:50] = 1
transform/radon_transform.py:
  87 : assert padded_image.shape[0] == padded_image.shape[1]
transform/_warps.py:
  150 : dst_corners[:, 0] = factors[1] * (src_corners[:, 0] + 0.5) - 0.5
  159 : tform.params[1, 0] = 0
  442 : xy[..., 0] = x0 + rho * np.cos(theta)
  514 : a[:,:,0] = a[:,:,1] = ... = b
util/arraycrop.py:
  78 : # if arr.shape[1] == 1 and arr.shape[0] == ndims:
  81 : # elif arr.shape[0] == ndims:
```

That tells me that if mutation is optional, scikit-image is simply going to ignore the array types that don't implement it. As will SciPy et al. Maybe that just is what it is, but what I want to know is whether it can be avoided by JAX adding a simple wrapper that maps

x[idx] = y

to

x = x.at[idx].set(y)

If that's not that hard to do, and if a lot of those cases are not inefficient (which I think is the case), then would it really be so bad adding that wrapper?
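A rough sketch of what such a wrapper layer could look like, using NumPy with a read-only array as a stand-in for an immutable array type. The names functional_set and MutableFacade are hypothetical, and functional_set only imitates the role of x.at[idx].set(y) by returning an updated copy:

```python
import numpy as np


def functional_set(x, idx, y):
    """Non-mutating stand-in for x.at[idx].set(y): returns an updated copy."""
    out = x.copy()
    out[idx] = y
    return out


class MutableFacade:
    """Hypothetical wrapper: looks mutable, but only rebinds an internal array."""

    def __init__(self, data):
        self._data = data

    def __setitem__(self, idx, y):
        # x[idx] = y becomes x = set(x, idx, y) behind the scenes.
        self._data = functional_set(self._data, idx, y)

    def __getitem__(self, idx):
        return self._data[idx]

    def unwrap(self):
        return self._data


base = np.zeros((2, 3))
base.setflags(write=False)    # stand-in for an immutable array

x = MutableFacade(base)
x[:, 0] = 1
print(x.unwrap().tolist(), base.tolist())
```

Each assignment costs a full copy in this sketch; whether that is acceptable depends on how often it happens in a loop, as discussed earlier in the thread.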

shoyer commented 3 years ago

The reason why x[idx] = y in JAX is challenging ultimately comes down to confusion between copies vs views. Specifically, consider f(x):

def f(x):
    x[:, 0] = 1  # something in-place
    ...

f(x) operates in-place, but jit(f)(x) would not (jit makes copies of array arguments).

So if we can figure out a usage pattern that avoids that confusion by recommending (or requiring?) a call to ensure a copy, I think we could make it work in JAX, e.g.,

y = x.copy()
y[:, 0] = z

If we have a reliable way to detect "views" (with reference counting?), we might be able to call this method ensure_copy() or mutable() rather than copy().

JAX would probably choose to make all operations "views" by default, just to preserve maximum flexibility and require best practices. In fact, I would guess that the default array type wouldn't even implement a __setitem__ -- you'd always have to call .copy() to get mutable array.
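The usage pattern can be tried out today with NumPy, using a read-only array as a stand-in for an array type that is immutable by default (this only imitates the behaviour; it is not how JAX is implemented):

```python
import numpy as np

x = np.zeros((2, 3))
x.setflags(write=False)       # stand-in for "immutable unless copied"

try:
    x[:, 0] = 1               # direct assignment is rejected
    rejected = False
except ValueError:            # NumPy raises ValueError for read-only arrays
    rejected = True

y = x.copy()                  # the copy is a fresh, writeable array
y[:, 0] = 1

print(rejected, y.flags.writeable, x.tolist())
```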

rgommers commented 3 years ago

If we have a reliable way to detect "views" (with reference counting?), we might be able to call this method ensure_copy() or mutable() rather than copy().

That's an interesting idea. I think it's feasible to build this, and the extra API surface of mutable() and the code changes needed in libraries using __setitem__ is a small price to pay for making this work well across all array libraries.

rgommers commented 3 years ago

A quick test to set a baseline, here are the results of running the test suites of some SciPy modules after making numpy.asarray and numpy.random.random return readonly arrays:

Testing the impact of the mutable() variant is a lot more difficult, but I'll see if I can find a volunteer for that.

rgommers commented 3 years ago

Coming back to this after finding it's probably too hard to implement .mutable() in the form discussed above.

In-place arithmetic should be supported by all arrays, but users should not assume that it either modifies generic array values in place or makes a copy. Instead, in-place arithmetic is a hint that it is safe to overwrite array values as a potential optimization. (As I noted earlier, this is basically decided already at the Python language level.)

Played with this some more:

import numpy as np
import jax.numpy as jnp

def f(mod):
    x = mod.arange(5)
    y = x[:2]
    y += 1
    return x

print(f(np))
print(f(jnp))

yields

[1 2 2 3 4]
[0 1 2 3 4]

This is actually 100% identical to the problem with slice assignment. So saying for += and other in-place operators "users will have to learn these rules themselves", but for slice assignment that we need a new usage pattern like .mutable() is pretty inconsistent.

As I noted earlier, this is basically decided already at the Python language level.

It seems to me that if JAX were principled on no mutation, it should have implemented __iadd__ and raised a TypeError from it. That Python has some rule to forward a dunder method is irrelevant from a user perspective.
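For illustration, opting out of the language-level fallback takes one method. This is a sketch with a hypothetical NoInplace class, not JAX's actual array type:

```python
class NoInplace:
    """Hypothetical immutable value type that rejects += explicitly."""

    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        return NoInplace(self.value + other)

    def __iadd__(self, other):
        # Defining __iadd__ stops Python from falling back to __add__.
        raise TypeError("in-place addition not supported; use x = x + other")


x = NoInplace(1)
try:
    x += 1
    message = None
except TypeError as exc:
    message = str(exc)

print(message)
print((x + 1).value)
```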

rgommers commented 3 years ago

@mattip is going to help with an experiment, creating a NumPy branch which sets flags.writeable to False on creation of a view (without setting it back to True if the view goes out of scope). That seems feasible, because it should happen in the same place where the .base attribute is set.

shoyer commented 3 years ago

This is actually 100% identical to the problem with slice assignment.

I'm not entirely sure I follow what you mean here.

Examples like your f(mod) are certainly not a great idea in any array library. Raising errors when attempting to write to views certainly seems like a good start.

So saying for += and other in-place operators "users will have to learn these rules themselves", but for slice assignment that we need a new usage pattern like .mutable() is pretty inconsistent.

If it helps, we could still call it .copy()? If you don't know if any expression (or input into a function) is going to give a view or a copy, then you need to make a defensive copy anyways in libraries like NumPy. This is not going to be maximally efficient but may be good enough for much generic code, and it's already done all over the place in idiomatic NumPy code. In fact it's already pretty hard to know when you'll get views vs copies unless you have an expert level knowledge of NumPy.

My main concern is that we should try to avoid baking specific view vs copy semantics directly into the standard. For example, if a new library wants to implement arithmetic as lazy expressions (like in dask or xtensor) that should be OK. Lazy expressions are effectively a form of views, so you can't count on being able to write something like this, even though it's safe in NumPy:

y = x + 1
y[:2] = 0  # this should raise an error, if y is a "view"

rgommers commented 3 years ago

My main concern is that we should try to avoid baking specific view vs copy semantics directly into the standard.

Yes, I agree completely. The question is just what form that takes. I'm trying to get to a place where we don't have to guess at all, but slice assignment is still allowed.

I'm not entirely sure I follow what you mean here.

  1. x += 1 is an in-place operation; it can be translated to a non-mutating form (x = ...) reliably if no views are involved.
  2. x[idx] = 1 is an in-place operation; it can be translated to a non-mutating form (x = ...) reliably if no views are involved.

rgommers commented 3 years ago

Here is a related PyTorch issue with ideas and rationale for adding immutable tensors and returning them in some cases where view/copy behaviour will be hard to predict: https://github.com/pytorch/pytorch/issues/44027

rgommers commented 3 years ago

Here is a hacky patch for NumPy that makes both the base array and the view read-only if a view is created:

```
diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c
index 5da1b5f29..e09065da6 100644
--- a/numpy/core/src/multiarray/arrayobject.c
+++ b/numpy/core/src/multiarray/arrayobject.c
@@ -231,7 +231,12 @@ PyArray_SetBaseObject(PyArrayObject *arr, PyObject *obj)
         return -1;
     }
 
+    /* Set writeable flags to warn when writing can affect a view*/
     ((PyArrayObject_fields *)arr)->base = obj;
+    if (PyArray_Check(obj)) {
+        PyArray_ENABLEFLAGS(obj, NPY_ARRAY_WARN_ON_WRITE);
+        PyArray_ENABLEFLAGS(arr, NPY_ARRAY_WARN_ON_WRITE);
+    }
     return 0;
 }
 
diff --git a/numpy/core/src/multiarray/ctors.c b/numpy/core/src/multiarray/ctors.c
index 7534c0717..9745f9623 100644
--- a/numpy/core/src/multiarray/ctors.c
+++ b/numpy/core/src/multiarray/ctors.c
@@ -31,6 +31,7 @@
 #include "get_attr_string.h"
 #include "array_coercion.h"
 
+#include "arrayobject.h" /* from this directory */
 /*
  * Reading from a file or a string.
@@ -552,6 +553,7 @@ PyArray_AssignFromCache_Recursive(
         else {
             PyArrayObject *view;
             view = (PyArrayObject *)array_item_asarray(self, i);
+            PyArray_CLEARFLAGS(view, NPY_ARRAY_WARN_ON_WRITE);
             if (view < 0) {
                 goto fail;
             }
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index 3f3bf9f70..b42d9e1f2 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -289,6 +289,7 @@ def _stride_comb_iter(x):
             # new array with different strides, but same data
             xi = np.empty(new_shape, dtype=x.dtype)
             xi.view(np.uint32).fill(0xdeadbeef)
+            xi.flags.writeable = True
             xi = xi[slices]
             xi[...] = x
             xi = xi.view(x.__class__)
```

It doesn't completely do the job it's supposed to do - it works for warning when doing a mutating operation that affects a view, however it also prevents regular slice assignment:

In [1]: x = np.ones((2, 3), dtype=np.float64)

In [2]: x[0, 0] = 1

In [3]: x[:2] = 1
... DeprecationWarning: Numpy has detected that you (may be) writing to an array with overlapping memory ...

In [4]: x.flags.writeable
Out[4]: True

In [5]: y = x[0, :]

In [6]: x.flags.writeable
... FutureWarning: future versions will not create a writeable array ...
Out[6]: True

In [7]: x.flags
Out[7]: 
  C_CONTIGUOUS : True
  F_CONTIGUOUS : False
  OWNDATA : True
  WRITEABLE : True  (with WARN_ON_WRITE=True)
  ALIGNED : True
  WRITEBACKIFCOPY : False
  UPDATEIFCOPY : False

In [8]: y.flags
Out[8]: 
  C_CONTIGUOUS : True
  F_CONTIGUOUS : True
  OWNDATA : False
  WRITEABLE : True  (with WARN_ON_WRITE=True)
  ALIGNED : True
  WRITEBACKIFCOPY : False
  UPDATEIFCOPY : False

That could be prevented by starting from within __setitem__ and passing down the information that the .writeable flag should not be touched to PyArray_SetBaseObject.

At that point there's still annoying behaviour left though, for example just evaluating x[:-1] in IPython will create a view that the interpreter holds on to - and hence flip the flag.

Trying out the effect on SciPy:

python -c "from scipy import ndimage; ndimage.test(extra_argv=['-W', 'error::DeprecationWarning'])"

gives 168 failed, 332 passed. Many failures are due to NumPy functions triggering the warning though - for example np.eye chokes on m[:M-k].flat[i::M+1] = 1 in its implementation (fixing that gets rid of 18 test failures already). Fixing that and a similar issue in np.indices gets me up to 126 failed, 374 passed. Still, that's a lot of failures - and some other SciPy modules won't even run with "error on deprecation", because something fails before pytest is done with test discovery.
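
For context, the `np.eye` line quoted above writes through a temporary slice view, which is the pattern a "warn on write when a view exists" flag reacts to. A minimal sketch in stock NumPy, assuming illustrative values for `N`, `M` and `k` and a hypothetical name `rows` for the temporary view (none of these are copied from the NumPy source):

```python
import numpy as np

# Illustrative sizes; k plays the role of the diagonal offset.
N, M, k = 3, 3, 0
m = np.zeros((N, M), dtype=np.int64)

# Basic slicing creates a view onto m rather than a copy.
rows = m[:M - k]
is_view = rows.base is m

# Assigning through the view's flat iterator mutates m itself.
i = k
rows.flat[i::M + 1] = 1

print(is_view)
print(m)
```

Because the write lands in `m` via a view, any scheme that warns or errors on "write while a view exists" has to special-case internal code like this.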

tl;dr this is going to be a real pain to do.

rgommers commented 3 years ago

I wrote some tests that make it easier to figure out which libraries have matching behaviour for in-place operations:

```python
"""
To create an environment with the seven array libraries used in this script
installed:

    conda create -n many-libs python=3.7
    conda activate many-libs
    conda install cudatoolkit=10.2
    pip install numpy dask toolz torch jax jaxlib tensorflow mxnet cupy-cuda102

Conda doesn't manage to find a winning combination here; pip has a hard time
too and probably not all constraints are satisfied, but nothing crashes and
the tests here work as they are supposed to.
"""
import numpy as np
import dask.array as da
import torch
import tensorflow as tf
import jax.numpy as jnp
import mxnet
try:
    import cupy as cp
except ImportError:
    # CuPy is GPU-only, so may not be available
    cp = None

import pandas as pd


def materialize(x, y):
    if isinstance(x, da.Array):
        x = x.compute()
    if isinstance(y, da.Array):
        y = y.compute()

    if mod == mxnet.nd:
        x0 = int(x[0, 0].asscalar())
        y0 = int(y[0, 0].asscalar())
    else:
        x0 = int(x[0, 0])
        y0 = int(y[0, 0])

    return x0, y0


def ones(mod):
    if mod in (da, mxnet.nd):
        x = mod.ones((3, 2), dtype=np.int32)
    else:
        x = mod.ones((3, 2), dtype=mod.int32)

    return x


def reshape(mod, x, shape):
    if mod == tf:
        return tf.reshape(x, shape)
    else:
        return x.reshape(shape)


def arange(mod, stop):
    if mod == tf:
        return tf.range(stop)
    elif mod == mxnet.nd:
        return mod.arange(stop, dtype=np.int64)
    else:
        return mod.arange(stop)


def diag(mod, x):
    if mod == tf:
        return tf.linalg.diag_part(x)
    else:
        return mod.diag(x)


def slice_assign(mod, x):
    # Add 2 to first row of (3, 2)-shaped input x
    assert x.shape[0] == 3
    idx = reshape(mod, arange(mod, 6), x.shape) < 3
    if mod in (tf, da):
        x = mod.where(idx, x+2, x)
    elif mod == jnp:
        x = x.at[idx].set(x[idx] + 2)
    else:
        x[:, 0] += 2

    return x


def f1(mod):
    "Add, then in-place subtract"
    x = ones(mod)
    y = x + 2
    y -= 1
    return materialize(x, y)


def f2(mod):
    "In-place add, alias, then in-place subtract"
    x = ones(mod)
    x += 2
    y = x
    y -= 1
    return materialize(x, y)


def f3(mod):
    "Slice, then in-place add on slice"
    x = ones(mod)
    y = x[:2, :]
    y += 1
    return materialize(x, y)


def f4(mod):
    "Reshape, then in-place add"
    x = ones(mod)
    y = reshape(mod, x, (2, 3))
    y += 1
    return materialize(x, y)


def f5(mod):
    "Slice with step size 2, then in-place add"
    x = ones(mod)
    y = x[::2, :]
    y += 1
    return materialize(x, y)


def f6(mod):
    "Alias, then in-place add"
    x = ones(mod)
    y = x
    y += 1
    return materialize(x, y)


def f7(mod):
    "Check which array types support slice assignment syntax"
    x = ones(mod)
    try:
        x[0, :] = 2
    except (NotImplementedError, TypeError):
        x = -9*ones(mod)

    return materialize(x, 2*ones(mod))


def f8(mod):
    "Do the actual slice assignment in the way each library wants it done"
    x = ones(mod)
    x = slice_assign(mod, x)
    y = 2 * ones(mod)
    return materialize(x, y)


def f9(mod):
    "`diag` is known to have inconsistent behaviour, test it"
    x = reshape(mod, arange(mod, 9), (3, 3))
    y = diag(mod, x)
    if mod == np:
        y = y + 2  # `y` is read-only
    else:
        y += 2

    return materialize(x, 2*ones(mod))


def f_indexing(mod, idx):
    "Indexing, then in-place add"
    x = ones(mod)
    try:
        y = x[idx, :]
    except TypeError:
        x = -9 * x
        y = ones(mod)

    y += 1
    return materialize(x, y)


def f10(mod):
    "Indexing with list of integers, then in-place add"
    return f_indexing(mod, [0, 1])


def f11(mod):
    "Boolean indexing, then in-place add"
    return f_indexing(mod, [True, True, False])


def f12(mod):
    "Indexing with ellipsis, then in-place add"
    return f_indexing(mod, Ellipsis)


libraries = {
    'numpy': np,
    'pytorch': torch,
    'MXNet': mxnet.nd,
    'dask': da,
    'tensorflow': tf,
    'jax': jnp,
}
if cp is not None:
    libraries['cupy'] = cp

results = libraries.copy()
funcs = [f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12]

for name, mod in libraries.items():
    results[name] = None
    res = []
    for func in funcs:
        x, y = func(mod)
        assert y == 2
        assert type(x) == int, (type(x), mod)
        res.append(x)

    results[name] = res

results = pd.DataFrame(results).T
results.columns = ['f{}'.format(i+1) for i in range(len(results.columns))]
print(results.sort_values(by='f2'))

print('\n')
for i, func in enumerate(funcs):
    print(func.__name__ + ':', func.__doc__)
    if i == 5:
        # First six functions are only about in-place operators, no slice
        # assignment - separate those in the output
        print('')
    if i == 8:
        # Last three functions are only about indexing behaviour
        print('')

print('\nCheck if behaviour equals that of NumPy\n(1 or -1 means copy/view behaviour '
      'mismatch, -9 means unsupported behaviour):\n')
results_vs_numpy = results.sort_values(by='f2') - results.loc['numpy']
# Set NA values (from exceptions) back to -9
results_vs_numpy[results_vs_numpy < -5] = -9
print(results_vs_numpy)
```

Results in:

            f1  f2  f3  f4  f5  f6  f7  f8  f9  f10  f11  f12
numpy        1   2   2   2   2   2   2   3   0    1    1    2
pytorch      1   2   2   2   2   2   2   3   0    1    1    2
MXNet        1   2   2   2   1   2   2   3   0    1    1    2
cupy         1   2   2   2   2   2   2   3   0    1    1    2
dask         1   3   1   1   1   1  -9   3   0    1    1    1
tensorflow   1   3   1   1   1   1  -9   3   0   -9   -9    1
jax          1   3   1   1   1   1  -9   3   0    1    1    1

f1: Add, then in-place subtract
f2: In-place add, alias, then in-place subtract
f3: Slice, then in-place add on slice
f4: Reshape, then in-place add
f5: Slice with step size 2, then in-place add
f6: Alias, then in-place add

f7: Check which array types support slice assignment syntax
f8: Do the actual slice assignment in the way each library wants it done
f9: `diag` is known to have inconsistent behaviour, test it

f10: Indexing with list of integers, then in-place add
f11: Boolean indexing, then in-place add
f12: Indexing with ellipsis, then in-place add

Check if behaviour equals that of NumPy
(1 or -1 means copy/view behaviour mismatch, -9 means unsupported behaviour):

            f1  f2  f3  f4  f5  f6  f7  f8  f9  f10  f11  f12
numpy        0   0   0   0   0   0   0   0   0    0    0    0
pytorch      0   0   0   0   0   0   0   0   0    0    0    0
MXNet        0   0   0   0  -1   0   0   0   0    0    0    0
cupy         0   0   0   0   0   0   0   0   0    0    0    0
dask         0   1  -1  -1  -1  -1  -9   0   0    0    0   -1
tensorflow   0   1  -1  -1  -1  -1  -9   0   0   -9   -9   -1
jax          0   1  -1  -1  -1  -1  -9   0   0    0    0   -1
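
To make the NumPy row of the table concrete, here is a reduced sketch of the `f3` and `f10` cases using NumPy only. It is a simplification of the script above, not a replacement for it; the variable names `f3_result`, `f10_result` and `out_of_place_result` are made up for this illustration.

```python
import numpy as np

# f3: basic slicing returns a view, so the in-place add writes into x.
x = np.ones((3, 2), dtype=np.int32)
y = x[:2, :]
y += 1
f3_result = int(x[0, 0])

# f10: indexing with a list of integers returns a copy, so x is untouched.
x = np.ones((3, 2), dtype=np.int32)
y = x[[0, 1], :]
y += 1
f10_result = int(x[0, 0])

# Out-of-place alternative: rebinding y never touches x, view or not.
x = np.ones((3, 2), dtype=np.int32)
y = x[:2, :]
y = y + 1
out_of_place_result = int(x[0, 0])

print(f3_result, f10_result, out_of_place_result)  # prints: 2 1 1
```

The first two values match the `numpy` entries for `f3` and `f10` in the results table; the third shows that the out-of-place form gives the same answer regardless of whether indexing produced a view or a copy.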

Some observations:

rgommers commented 3 years ago

I'm leaning towards the following choices:

  1. Remove out= completely, and don't replace it with anything
  2. Do not try to make arrays which have a view onto them read-only
  3. Allow usage of in-place operators and slice assignment
  4. Add a recommendation that users avoid any mutating operation when a view may be involved.
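
As an illustration of point 4, a user who cannot rule out aliasing could fall back to an out-of-place operation. The sketch below is NumPy-specific: `np.shares_memory` is a NumPy function and is not something the discussion above proposes for the standard, and `add_inplace_if_unshared` is a hypothetical helper invented for this example.

```python
import numpy as np


def add_inplace_if_unshared(target, other_arrays, value):
    """Hypothetical helper: add in place only if no listed array aliases target."""
    if any(np.shares_memory(target, other) for other in other_arrays):
        # A view may be involved: return a new array instead of mutating.
        return target + value
    target += value
    return target


x = np.ones((3, 2), dtype=np.int32)
view = x[:2, :]
independent = x[:2, :].copy()

a = add_inplace_if_unshared(view, [x], 1)         # aliases x: new array returned
b = add_inplace_if_unshared(independent, [x], 1)  # no aliasing: mutated in place

print(int(x[0, 0]), a is view, b is independent)
```

Libraries without an aliasing check would have to use the out-of-place form unconditionally to get the same portability.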

shoyer commented 3 years ago

> Here is a hacky patch for NumPy that makes both the base array and the view read-only if a view is created:

Would it be more practical to make only the view read-only? If you modify the base array, then in addition to errors like the IPython example you show, functions that create views would now effectively be mutating their input arguments.
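
A rough user-level analogue of "only the view is read-only" can be written with NumPy's existing `flags.writeable` attribute. This is only a sketch of the semantics, not the C-level patch under discussion, and `readonly_view` is a hypothetical helper name:

```python
import numpy as np


def readonly_view(base, index):
    """Hypothetical helper: index `base` and mark only the result read-only."""
    view = base[index]
    view.flags.writeable = False
    return view


x = np.ones((3, 2), dtype=np.int32)
y = readonly_view(x, (slice(0, 2), slice(None)))

x += 1  # the base stays writeable, and the view observes the change

try:
    y += 1
    view_write_blocked = False
except ValueError:
    # NumPy refuses in-place writes to a read-only array.
    view_write_blocked = True

print(x.flags.writeable, int(y[0, 0]), view_write_blocked)
```

With this arrangement, creating a view does not change any property of the base array, so functions that create views do not mutate their inputs.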

rgommers commented 3 years ago

Yes, that's a good point. I undid that change, and added some more fixes (current patch here). That at least makes the NumPy test suite start and run to completion without crashing, resulting in:

1268 failed, 10174 passed, 67 skipped, 10 errors

The troubles are:

I think it would take too much time to complete this change and make the NumPy test suite (mostly) pass - and that'd be needed to be able to assess the impact on SciPy and other packages.

That, plus the fact that it would be a massive backwards-compatibility break for all of NumPy, PyTorch, CuPy and MXNet, makes me think we should simply go with "Add a recommendation that users avoid any mutating operation when a view may be involved".

rgommers commented 3 years ago

Everyone in the meeting today was good with points 1-4 in https://github.com/data-apis/array-api/issues/24#issuecomment-689907016. I'll open a PR that discusses mutability, and @kgryte will remove out=.