data-apis / array-api

RFC document, tooling and other content related to the array API standard
https://data-apis.github.io/array-api/latest/
MIT License

DLPack and readonly/immutable arrays #191

Closed seberg closed 8 months ago

seberg commented 3 years ago

Sorry if there is some more discussion (or just clarification) in the pipeline. But I still think this is important, so here is an issue :). I realize that there was some related discussion in JAX before, but I am not clear how deep it really was. (I have talked with Ralf about this – we disagree about how important this is – but I wonder what others think, and we need to at least have an issue/clear decision somewhere.)


The problem (and current state/issues, also with the buffer protocol, which supports "readonly")

NumPy (and the buffer protocol) has readonly (`writeable=False`) arrays. NumPy actually has a slight distinction (I do not think that matters, though):

* Readonly arrays for which NumPy owns the memory. These can be set to writeable by a (power) user.
* Arrays that are truly readonly.

The current situation in the "buffer protocol" (and array protocols) world is that NumPy supports writeable and readonly arrays, JAX is immutable (and thus readonly), while PyTorch is always writeable:

```python
import jax
import numpy as np
import torch
```

JAX to NumPy works correctly:

```python
jax_tensor = jax.numpy.array([1., 2., 3.], dtype="float32")
numpy_array = np.asarray(jax_tensor)
assert not numpy_array.flags.writeable  # JAX exports as readonly.
```

NumPy to JAX ignores that NumPy is writeable when importing (same for `memoryview`):

```python
numpy_arr = np.arange(20).astype(np.float32)
jax_tensor = jax.numpy.asarray(numpy_arr)  # JAX imports the mutable array, though
numpy_arr[0] = 5.
print(repr(jax_tensor))
# DeviceArray([5., 1., 2., ...], dtype=float32)
```

PyTorch also breaks the first part when talking to JAX (although you have to go via ndarray, I guess):

```python
jax_tensor = jax.numpy.array([1., 2., 3.], dtype="float32")
torch_tensor = torch.Tensor(np.asarray(jax_tensor))
# UserWarning: The given NumPy array is not writeable, ... (scary warning!)
torch_tensor[0] = 5  # modifies the "immutable" jax tensor
print(repr(jax_tensor))
# DeviceArray([5., 2., 3.], dtype=float32)
```

And NumPy arrays can be backed by truly read-only memory:

```python
arr = np.arange(20).astype(np.float32)  # torch default dtype
np.save("/tmp/test-array.npy", arr)
# The following is memory mapped read-only:
arr = np.load("/tmp/test-array.npy", mmap_mode="r")
torch_tensor = torch.Tensor(arr)  # (scary warning here)
torch_tensor[0] = 5.
# segmentation fault
```

This is in a world where "readonly" information exists, but JAX and PyTorch don't support it, and we leave it up to the user to know about these limitations. Because of that pretty severe limitation, PyTorch decides to give that scary (and ugly) warning. JAX tries to protect the user during export – although not in DLPack – but silently accepts that it is the user's problem during import.

I do realize that within the "array API" you are not supposed to write to an array. So within the "array API" world everything is read-only, and that would solve the problem. But that feels like a narrow view to me: we want DLPack to be widely adopted, and most array API users are *also* NumPy, PyTorch, JAX, etc. users (or interact with them). So they should expect to interact with NumPy, where mutability is common. Also, `__dlpack__` could very much be more generally useful than the array API itself.

Both JAX (immutable) and PyTorch (always writeable) have limitations that their users must currently be aware of when exchanging data. But it feels strange to me to force these limitations on NumPy. Especially, I do not like them in `np.asarray(dlpack_object)`. To get around the limitations in `np.asarray` we would have to:

* Import all DLPack arrays as readonly or always copy? This could be part of the standard. In that case it would make the second point unnecessary. But: that prevents any in-place algorithms from accepting DLPack directly.
* Possibly, export only writeable arrays (to play safe with PyTorch). Seems fine to me, at least for now (a bit weird if combined with the first point, and doesn't round-trip).

Clarifying something like that is probably sufficient, at least for now.


But IMO, the proper fix is to add a read-only bit to DLPack (or the __dlpack__ protocol). Adding that (not awkwardly) requires either extending the API (e.g. having a new struct to query metadata) or breaking ABI. I don't know what the solution is, but whatever it is, I would expect that DLPack is future-proofed anyway.

Once DLPack is future-proofed, the decision of adding a flag could also be deferred to a future version…

As is, I am not sure NumPy should aim to make `__dlpack__` its preferred exchange protocol in the future (whether that is likely to happen or not). Rather, it feels like NumPy should support DLPack mainly for the sake of those who choose to use it. (Unlike the buffer protocol, which users use without even knowing, e.g. when writing Cython typed memoryviews.)

Neither JAX nor PyTorch currently quite supports "readonly" properly (and maybe they never will). But I do not think that limitation is an argument against supporting it properly in `__dlpack__`. NumPy, Dask, ~~CuPy~~, Cython (typed memoryviews) do support it properly, after all. It seems almost like turning a PyTorch "user problem" into an ecosystem-wide "user problem"?

Of course, NumPy can continue talking the buffer protocol with Cython and many others (and likely will in any case). And I can live with the limitations, at least in an `np.from_dlpack`. But I don't like them in `np.asarray()`, and they just seem like unnecessary issues to me. (In that case, I may still prefer not to export readonly arrays.)


Or am I the only person who thinks this is an important problem that we have to solve for the user, rather than expect the user to be aware of the limitations?


EDIT: xref a discussion at cupy, which mentioned that nobody supports the readonly flag that __cuda_array_interface__ actually includes, and asked cupy for support. (I am not sure why the C level matters – unless it is a high-level C-API – the Python API is what matters most? NumPy has a PyArray_FailUnlessWriteable() function for this in the public API.)

leofang commented 3 years ago

Just a drop-by comment (on my cell): CuPy currently does not support the readonly flag and does not have any plan afaik.

rgommers commented 3 years ago

Thanks for the summary Sebastian!

Arrays that are truly readonly.

Let me add one other example you gave to make it more concrete: an ndarray that's backed by an HDF5 file opened in read-only mode.

There's other ways to get read-only arrays (memmap, broadcast_arrays, diag), but I think the HDF5 one will be the most widely used one.
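To make these read-only sources concrete, here is a short sketch of two of the cases (`broadcast_to` standing in for the `broadcast_arrays` case, and a memory-mapped `.npy` file standing in for HDF5; the temp-file path is illustrative):

```python
import os
import tempfile
import numpy as np

# Broadcasting produces views in which many elements alias the same
# memory, so NumPy marks the result read-only:
b = np.broadcast_to(np.arange(3), (4, 3))
assert not b.flags.writeable

# A file memory-mapped in read-only mode yields a truly read-only
# array: writing through it could not work even in principle.
path = os.path.join(tempfile.gettempdir(), "readonly-demo.npy")
np.save(path, np.arange(5, dtype=np.float32))
m = np.load(path, mmap_mode="r")
assert not m.flags.writeable
```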

I do realize that within the "array API", you are not supposed to write to an array. So within the "array API" world everything is read-only and that would solve the problem.

I think you wrote this with a certain scenario in mind, but just to be sure: this is not the case in general. There are in-place operators, and item and slice assignment is supported. Also, DLPack is a memory sharing protocol, and there's in principle nothing that forbids writing to that shared memory.

The only thing the array API says is:

This is in a world where "readonly" information exists, but JAX and PyTorch don't support it, and we leave it up to the user to know about these limitations. Because of that pretty severe limitation PyTorch decides to give that scary (and ugly) warning. JAX tries to protect the user during export – although not in DLPack – but silently accepts that it is the users problem during import.

I agree, this is not ideal. It's also very much nontrivial to implement (see, e.g., https://github.com/pytorch/pytorch/issues/44027). For PyTorch it may still come at least, for JAX and TensorFlow I don't think there's any chance of that. So given that this doesn't exist in most libraries, the options are:

  1. emit a warning and continue
  2. raise an exception, force the user to make a copy
  3. silently make a copy before exporting memory, to be safe
  4. just continue, and document what the user should/shouldn't do
  • Import all DLPack arrays as readonly or always copy? This could be part of the standard. In that case it would make the second point unnecessary. But: That prevents any in-place algorithms from accepting DLPack directly.

Import as read-only is possible for NumPy, but not for PyTorch - so it's not a general solution. Always copy goes against the main purpose of the protocol I think.

  • Possibly, export only writeable arrays (to play safe with PyTorch). Seems fine to me, at least for now (a bit weird if combined with first point, and doesn't round-trip)

This makes sense to me, in particular for arrays that are backed by memory-mapped files. There's other cases here for other libraries, like lazy data structures. If the protocol is for in-memory data sharing, the data should be in memory. That does still leave the question if the copy should be made by the library, or if the library should raise and make the user do it.
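The four importer-side options above can be sketched as a consumer policy (a hypothetical `consume` helper; NumPy stands in for a consumer that wants writable memory):

```python
import warnings
import numpy as np

def consume(arr, policy="warn"):
    """Hypothetical consumer-side handling of read-only input;
    `policy` mirrors options 1-4 above."""
    if arr.flags.writeable:
        return arr
    if policy == "warn":     # 1. emit a warning and continue
        warnings.warn("importing read-only memory as if writable")
        return arr
    if policy == "raise":    # 2. raise, forcing the user to copy
        raise BufferError("input is read-only; pass a copy instead")
    if policy == "copy":     # 3. silently copy before sharing, to be safe
        return arr.copy()
    return arr               # 4. just continue; document the caveat
```

Of these, only option 3 is both safe and silent, at the price of the copy that the protocol was designed to avoid.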

rgommers commented 3 years ago

NumPy, dask, ~cupy~, cython (typed memoryviews) do support it properly after all. It seems almost like turning a PyTorch "user problem" into an ecosystem wide "user problem"?

Dask also doesn't quite support it, right? It's only able to hold numpy arrays with the writeable flag set, but there's no support at the dask.array level. So really out of all array/tensor libraries, it's NumPy that supports both modes, and PyTorch/JAX/TensorFlow/CuPy/MXNet/Dask don't.

And even in NumPy it's a fairly niche feature imho - I'd say the "open a file in read-only mode" is the main use case. It's important enough for people that do a lot of their work with HDF5/NetCDF files to carefully consider, but I'd not say this is an "ecosystem wide user problem".

But IMO, the proper fix is to add a read-only bit to DLPack (or the __dlpack__ protocol).

This would not help much compared to the situation now. Take a library like PyTorch, what is it supposed to do when it sees a read-only bit? It has no way to create read-only tensors, so it still has exactly the same set of choices 1-4 as I listed above. Same for JAX - it only has one kind of array, so nothing changes.

kgryte commented 2 years ago

DLPack support was removed from the specification for asarray in https://github.com/data-apis/array-api/pull/301.

Concerns are still applicable for the buffer protocol and its support for "readonly".

However, it would not seem that there's anything further we need to address in the spec. If there is, feel free to reopen this issue to continue the discussion.

rgommers commented 2 years ago

DLPack support was removed from the specification for asarray in #301.

We had some more recent discussion in a NumPy community meeting, and I think that aligned with the outcome here - removing DLPack support from asarray and leaving it to a specific separate function (from_dlpack) makes the introduction safer (since asarray is already widely used).

Concerns are still applicable for the buffer protocol and its support for "readonly".

However, it would not seem that there's anything further we need to address in the spec. If there is, feel free to reopen this issue to continue the discussion.

Good point. Since it's older, there are no/fewer concerns about the readonly behavior of the buffer protocol, but yes, it's still a potential footgun. Anyway that existed for a long time before, so I'm happy to leave that as is and leave it up to libraries to deal with the readonly flag by raising a warning, an exception or by ignoring it as they see fit.

leofang commented 2 years ago

Reopening this issue to address readonly tensors. As discussed in the Jan-6-2022 meeting, it is preferable to still export readonly tensors and let the consumer decide how to handle them. So, the above comment

Anyway that existed for a long time before, so I'm happy to leave that as is and leave it up to libraries to deal with the readonly flag by raising a warning, an exception or by ignoring it as they see fit.

would not be fully applicable. The exporter should not raise.

This issue will close numpy/numpy#20742.

rgommers commented 2 years ago

The exporter should not raise.

Agreed, with "libraries" I meant importers. The buffer protocol has a readonly flag, so there's no problem for exporters.

pearu commented 10 months ago

I have been reading lengthy discussions about the problems of sharing buffers between different array libraries: all points are very much appreciated.

It seems to me that the fundamental cause of the issue boils down to "incomplete" design choices of these libraries when considering the problem of buffer sharing, even when their design may be complete within the objectives of a particular library. I have a couple of suggestions for how to alleviate this without requiring changes to the design of the libraries, but rather by introducing an intention bit in the from_dlpack function (see the end of this comment). Back to the fundamental cause...

For instance, PyTorch tensors are always writable, which enables efficient usage of computational resources; however, practice has shown that such a design is far from optimal, as there exist use cases within the objectives of PyTorch where one would wish to make PyTorch tensors non-writable. As an example, the indices of sparse tensors must be immutable by definition, but currently it is fairly easy to crash Python because the indices can be arbitrarily mutated by users or as a side effect of a software bug (see e.g. https://github.com/pytorch/pytorch/issues/44027#issuecomment-1247735981).

JAX represents the other side of the spectrum, where arrays are immutable by the design of operating with pure functions. But even here one could argue that the JAX design is incomplete with respect to tasks that do not require the pure-function assumption, say, initializing an array with computation results that may be obtained using some other array library.

In general, the lifetime of an array buffer consists of (i) memory allocation, (ii) initialization, (iii) updates, (iv) usage, and finally, (v) memory release. In stages (ii) and (iii) the array must be mutable, while in stage (iv) the array ought to be immutable. Allowing multiple switches between the update and usage stages is expected. Sharing array buffers in all of stages (ii), (iii), and (iv) makes practical sense, IMHO.

In terms of this array lifetime, the current PyTorch implementation considers immutable array usage irrelevant, while the current JAX design fuses initialization into the array construction phase and disallows updates. Disclaimer: I don't mean to imply that these libraries ought to change their behavior, although in some cases, such as PyTorch, it would be beneficial.

I would expect that the current design choices of libraries (which may change over time in any direction) ought not to restrict the array API standard, but rather that the standard should unify them. For instance, extending the DLPack protocol with a read-only flag makes a lot of sense to me, as it will enable sensible buffer-sharing approaches between array libraries that otherwise have contradicting designs.

For instance, when a jax.Array buffer is exported via the DLPack protocol with the read-only flag set to True, PyTorch ought to have a choice to import it via copy (when it is known that PyTorch is going to update the buffer) or via view (when it is known that PyTorch is going to use the buffer as input data only). The choice can be expressed as a keyword argument to the from_dlpack function. If the stated intention and the state of the read-only flag contradict, from_dlpack would raise an exception. (Btw, the copy approach is not strictly necessary, as it is equivalent to requesting a view followed by a clone operation, and all of that ought to be implemented by PyTorch.) What do you think?

The current behavior is unhealthy in many ways indeed. For instance, the DLPack protocol can be misused to expose immutable data to mutation (another example is https://github.com/google/jax/issues/19123). Or, when a jax.Array is exported to PyTorch via __cuda_array_interface__, an exception is raised even though the PyTorch program may have no intention of mutating the exported buffers (read: DLPack with a read-only flag should not copy this behavior; instead, it should allow specifying the intention of the importer and respect it when sensible).
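A rough sketch of the importer-intent semantics proposed above (a hypothetical `import_with_intent` helper; `np.asarray` stands in for the zero-copy DLPack exchange, since the flag handshake does not exist yet):

```python
import numpy as np

def import_with_intent(x, want_writable=False):
    # np.asarray stands in for the zero-copy from_dlpack exchange.
    arr = np.asarray(x)
    exporter_readonly = not arr.flags.writeable
    if want_writable and exporter_readonly:
        # The importer intends to mutate: a view would break the
        # exporter's immutability guarantee, so clone instead
        # (the "view followed by a clone" equivalence noted above).
        return arr.copy()
    return arr  # a view: the importer promises read-only usage

src = np.arange(3.0)
src.flags.writeable = False  # emulate an immutable exporter such as JAX

view = import_with_intent(src, want_writable=False)
assert np.shares_memory(view, src)

owned = import_with_intent(src, want_writable=True)
assert owned.flags.writeable and not np.shares_memory(owned, src)
```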

pearu commented 10 months ago

For instance, extending dlpack protocol with a read-only flag...

dlpack already provides the read-only flag in the new DLManagedTensorVersioned struct: https://dmlc.github.io/dlpack/latest/c_api.html#_CPPv424DLManagedTensorVersioned
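For reference, the flag is a bit in the `uint64_t flags` member of `DLManagedTensorVersioned`; bit 0 is `DLPACK_FLAG_BITMASK_READ_ONLY`. A minimal Python-side check of that bit:

```python
# Bit 0 of DLManagedTensorVersioned.flags marks read-only data
# (DLPACK_FLAG_BITMASK_READ_ONLY in dlpack.h, DLPack >= 1.0).
DLPACK_FLAG_BITMASK_READ_ONLY = 1 << 0

def is_read_only(flags: int) -> bool:
    """True if the exporter marked the shared memory read-only."""
    return bool(flags & DLPACK_FLAG_BITMASK_READ_ONLY)

assert is_read_only(DLPACK_FLAG_BITMASK_READ_ONLY)
assert not is_read_only(0)
```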

leofang commented 10 months ago

dlpack already provides the read-only flag in the new DLManagedTensorVersioned struct:

Yes but we need a handshaking mechanism in Python before introducing the support, part of which is what #602 intends to address.

rgommers commented 10 months ago

Agreed, we should finally get gh-602 merged.

I don't think simply having access to that flag matters too much for readonly behavior by itself, though. I don't quite agree that the current state is that unhealthy. There aren't many bug reports, and those that exist are usually user error, in the same way as accidentally mutating a view in PyTorch-only or NumPy-only code and thereby modifying the base tensor/array as well. https://github.com/google/jax/issues/19123 is also simply that; it is not a real-world bug.

dlpack with read-only flag should not copy this behavior, instead, it should allow specifying the intention of the importer

The problems invariably occur by accident. So if PyTorch had true read-only tensors, that would help increase robustness against unintended mutations a lot. And then no new keyword is needed: JAX arrays would always translate to PyTorch read-only tensors. But if it's up to the user to write from_dlpack(x, readonly=False), that is already possible, since the behavior you suggest is equivalent to copy(from_dlpack(x)).

pearu commented 10 months ago

Agreed that https://github.com/google/jax/issues/19123, or any user misuse of mutating immutable data, can be considered not a real bug when it is documented as undefined behavior. If it is not a bug, then it must be a weakness of the cooperation between array libraries that (i) allows misuse and (ii) offers no good alternative that enables strict correctness: existing array libraries are not always compatible (yet) in exchanging data bidirectionally.

Re consumer.copy(consumer.from_dlpack(x)): this is correct iff the consumer library supports read-only arrays. PyTorch does not, and currently torch.from_dlpack throws an exception on immutable (e.g. JAX) arrays, so torch.clone never gets a chance. So people invent workarounds, such as using more permissive libraries to hide the immutability of the producer [this is what I mean by unhealthy] (e.g. using cupy, https://github.com/pytorch/pytorch/issues/32868#issuecomment-593764232), but over time these loopholes may be closed (IIRC, cupy also plans to throw on readonly arrays).

rgommers commented 10 months ago

it must be a weakness of the co-operativity between array libraries that (i) allows misuse and (ii) there exists no good alternative that enable strict correctness

I'd say it's a weakness (limitation is a nicer word probably) of the library in itself. There really isn't any difference between this unwanted mutation happening in CuPy-only or PyTorch-only code (e.g., the sparse tensor indices example you gave) vs. unwanted mutation on a tensor that happened to come in via DLPack.

currently torch.from_dlpack throws an exception on immutable (e.g. jax) arrays

I don't think that's the case? There's no sign of read-only handling in dlpack.h or DLConvertor.cpp, and the docstring correctly explains that the data is always exposed as shared memory. I did a quick check with latest JAX + PyTorch 2.0, and things work as expected, no exception.

An exception is thrown for NumPy read-only arrays though (but those are quite rare, and the user can simply flip the flag on the input numpy array if this happens): https://github.com/pytorch/pytorch/blob/035e55822ad123b02fbd9d91e1185f5c07af0ddd/torch/csrc/utils/tensor_numpy.cpp#L427-L429.
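The flag flip mentioned here works because NumPy owns the memory in the common case; a minimal illustration (a truly read-only backing, such as an `mmap_mode="r"` load, would refuse the flip):

```python
import numpy as np

a = np.arange(3.0)         # NumPy owns this memory
a.flags.writeable = False  # the kind of array PyTorch rejects
# ...after the exchange, the owner may simply flip the flag back:
a.flags.writeable = True   # allowed, since NumPy owns the buffer
a[0] = 7.0
```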

IIRC, cupy also plans to throw on readonly arrays).

That would be the wrong decision I think, that would make interop between CuPy and JAX completely useless/impossible.

The fact remains that if you have one library that is only read-only and one that is never read-only, you have only a couple of choices, independent of whether there are flags/keywords:

  1. make a copy (performance hit)
  2. expose shared memory (robustness of the always-readonly lib is a bit degraded, now susceptible to user errors)
  3. raise an exception (functionality taken away)

JAX used to do (1) and users complained about this, so JAX changed to (2). This works much better, with few complaints AFAIK. So it looks to me like, in the absence of evidence of too many bug reports, that state should stay that way. You simply cannot have complete robustness as well as zero-copy behavior with the current state of JAX -> PyTorch/CuPy interop. Forcing users to make a copy when a readonly flag is set would just go back to (1).

Adding a keyword to from_dlpack is at best a way to emphasize to the user that the docs warn about mutation, and let them say "yes I read the docs, and I promise I won't mutate this data". But they're still at (2), and tracking views in a large program can be tricky, so robustness is still imperfect just like it is now.

pearu commented 10 months ago

currently torch.from_dlpack throws an exception on immutable (e.g. jax) arrays

I don't think that's the case?

Sorry, my bad. It was torch.asarray throwing on __cuda_array_interface__ with read-only flag set to True.

rgommers commented 10 months ago

Ah yes, that is an inconsistency - probably should be changed in PyTorch to align with the DLPack behavior.

pearu commented 10 months ago

While we are at (2), it is because (i) the current DLManagedTensor does not have a read-only flag and (ii) the view functionality is required for performance. With DLManagedTensorVersioned, there is a risk that DLPack consumers will respect the read-only flag despite (ii) and will throw, similar to the consumers of the Numba CUDA/NumPy array interfaces. Otherwise, what is the point of setting the readonly flag if consumers will ignore it?

rgommers commented 10 months ago

We should add advice not to do that then.

Otherwise, what is the point of setting readonly flag if consumers will ignore it?

It's of use mainly if a library actually has a read-only mode. So it will definitely help if NumPy is the consumer, but it won't change too much for PyTorch/CuPy.

leofang commented 8 months ago

FYI CuPy is investigating the readonly support: https://github.com/cupy/cupy/pull/8118.

I think with #602 merged, we can close this issue now?

rgommers commented 8 months ago

FYI CuPy is investigating the readonly support: cupy/cupy#8118.

Interesting, thanks for sharing.

I think with #602 merged, we can close this issue now?

Agreed. We have a readonly flag in the latest DLPack version now, plus a way to upgrade to it with max_version (gh-602), so it looks like from the perspective of the standard there is nothing left to do.

lucascolley commented 8 months ago

To be clear, does from_dlpack(some_read_only_array) make a copy?

rgommers commented 8 months ago

No, you have to set copy=True to make a copy. "readonly" is simply a flag at the C level, which should be honored if both producer and consumer support read-only arrays.

seberg commented 8 months ago

The answer right now is: undefined. Unless you pass copy=True/False, the return value may or may not be a copy (preferentially not, but no guarantees).
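Concretely, with NumPy as both producer and consumer (a case where zero-copy is always possible), the default exchange is a view; the `copy` keyword makes the behavior explicit:

```python
import numpy as np

x = np.arange(3.0)

# Default (copy=None in the standard): the result may or may not be a
# copy. NumPy -> NumPy can always share memory, so here it is a view:
y = np.from_dlpack(x)
assert np.shares_memory(x, y)

# copy=True would force a fresh allocation; copy=False would raise
# whenever sharing is impossible (e.g. a cross-device transfer).
```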

lucascolley commented 8 months ago

Ok. What is the correct way to handle the following consumer scenario:

If we only use np.asarray, we error out on read-only arrays from JAX. If we always copy (e.g. with from_dlpack(x, copy=True)), we introduce redundant copies whenever the input is already a NumPy array.

I suppose I am looking for something along the lines of the copy=None 'only if needed' semantics, while still avoiding device transfers.

(Perhaps the answer is that we should just use np.asarray and error out for JAX)

rgommers commented 8 months ago

I suppose I am looking for something along the lines of the copy=None 'only if needed' semantics, while still avoiding device transfers.

If you don't explicitly use the device keyword, you won't get a device transfer. So just from_dlpack(x, copy=None) will do this.

If we only use np.asarray, we error out on read-only arrays from JAX

I'm not sure why that would be. If that's actually happening, it's probably a bug (assuming they're CPU arrays).

Nothing should be missing; if there's a real-world problem can you give some actual code?

lucascolley commented 8 months ago

Checkout the first commit of scipy/scipy#20085 :)

rgommers commented 8 months ago

The first commit contains nothing related; the second commit changes asarray to copy. It's not clear to me why; asarray works:

```python
>>> import numpy as np
>>> import jax.numpy as jnp
>>>
>>> x = jnp.arange(3)
>>> y = np.asarray(x)
>>> y.flags.writeable
False
```

lucascolley commented 8 months ago

Sorry, the asarray itself doesn't throw errors (and keeps writeable False). My question is: do we want to make a copy here so that we can go on to use NumPy in-place ops? Or do we just let exceptions be raised, similarly to when a device copy would be needed?

(Checking out the first commit lets you see the errors that occur without introducing the copies, with python dev.py test -s cluster -b jax)

pearu commented 8 months ago

"readonly" is simply a flag at the C level, which should be honored if both producer and consumer support read-only arrays.

I think this statement ought to be more explicit in the standard, say, via an additional note:

Otherwise, a non-copy exchange of data becomes impossible between producers that do not support writable arrays (say, JAX) and consumers that do not support read-only arrays (say, PyTorch).

rgommers commented 8 months ago

Sorry, the asarray itself doesn't throw errors (and keeps writable False). My question is, do we want to make a copy here so that we can go on to use np in-place ops? Or do we just let exceptions be raised, similarly to if a device copy would be needed?

I'll have a look, thanks. This isn't really DLPack related, because the exact same thing should happen with NumPy read-only arrays. Also, with a few exceptions, all APIs in SciPy should not be modifying their input arrays, so I guess there'll be a zero-copy way to work around the exact problem here.

Let's continue this on the SciPy PR - I'll try your code and comment there.

I think this statement ought to be more explicit in the standard, say, via an additional note:

Sure, that seems useful to me. That'll be a note that simply documents the way things currently work anyway I believe. Let me open a PR for that. "should" as a recommendation is probably right here, rather than "must" - but let's leave that level of detail for the PR.

rgommers commented 8 months ago

That'll be a note that simply documents the way things currently work anyway I believe. Let me open a PR for that.

Done in gh-749.