dmlc / dlpack

common in-memory tensor structure
https://dmlc.github.io/dlpack/latest
Apache License 2.0

Specify synchronization semantics #57

Closed kkraus14 closed 3 years ago

kkraus14 commented 3 years ago

From the CUDA side, nearly all of the DL frameworks and array libraries that support dlpack use CUDA streams and stream-order both their computations and their memory allocations. In its current form, dlpack neither specifies any synchronization semantics nor provides a way for a producer-consumer pair to exchange the information needed to keep stream-ordering their computations.

I imagine there's a similar problem in other contexts as well (OpenCL, ROCm, etc.) where maybe it's possible to generalize an approach.

cc @oleksandr-pavlyk

veritas9872 commented 3 years ago

I am also curious as to whether DLPack is blocking.

leofang commented 3 years ago

I would like to add that another data exchange format, the CUDA Array Interface, has a recent revision that @kkraus14, I, and many others have contributed to; see the Synchronization section of its v3 protocol. It turns out to be nontrivial to ensure that the data being exchanged is correct with respect to the semantics of CUDA (or HIP or OpenCL, if one cares). It appears that DLPack currently lacks this level of consideration and implicitly assumes that the Producer and the Consumer exchanging the data live on the same (whatever) stream/queue. In hindsight, this is a very strong assumption.

I understand that no one has complained about DLPack lacking such a synchronization requirement (in principle this should have been a serious issue since the advent of DLPack, but in practice, even on the CUDA Array Interface side, we never received a single bug report for not guaranteeing the synchronization behavior). But if the Python Array API Consortium is going to adopt DLPack as the official zero-copy standard, this will need to be rigorously sorted out, as was done in CUDA Array Interface v3.

cc: @rgommers (for awareness)

veritas9872 commented 3 years ago

I think making synchronization optional could also be considered. As @leofang has already mentioned, in most high-level CUDA tasks people use the array on the same stream. However, forcing synchronization is time-consuming and goes against one of the main benefits of CUDA programming: asynchronous host/device behavior.

tqchen commented 3 years ago

Thanks @leofang @kkraus14. The issue of synchronization has been thought about, and it is certainly possible to come up with a specification that allows more asynchronous behavior. See also my previous related comment here: https://github.com/data-apis/consortium-feedback/issues/1#issuecomment-675230486

The main design goal is to get folks to agree on a common set of things. The main problem is not about a solution for synchronization, but to get frameworks to agree on the mechanism.

Most modern frameworks have their own opinions about allocation and scheduling, and in many cases these opinions vary. For example, both TF-RT and MXNet have their own internal schedulers. New abstractions are also being proposed (e.g. graph-based scheduling).

Based on that observation, we concluded that it is hard to get people to agree on the synchronization and allocation mechanism (yet), and resorted to a simpler convention: we require the producer and consumer to be on the same stream.

Most protocols that bring in synchronization behavior will inevitably embed the stream or context into the data structure. This defers the synchronization burden to the runtime implementation of the system (the consumer). While this is certainly more flexible, it brings engineering cost to the developers who are supposed to consume the DLPack array.

Additionally, while a small number of expert programmers may want to think about synchronization, a broader range of programmers may not even consider this perspective and only use one stream throughout their applications. Enabling context in DLPack means this complexity falls on the crowd rather than on a few experts. It also brings extra burden to the compilers themselves: a compiler would need to generate optional synchronization code based on the streams, which is non-trivial.

Finally, the simple protocol (asking the producer and consumer to be on the same stream) does not prevent async computation, as no synchronization is needed if both are on the same stream (which is the case for most apps). Alternative APIs can also be designed to request a DLPack array that is presented in a certain context, for example the optional API below:

array.to_dlpack_on_stream(stream_number)

Such an API puts the burden of synchronization on the producer, which is not necessarily more complex than putting it on the consumer, but it reduces the mental burden on the consumer side (think about common developers who might want to consume the array).
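To make the producer-side idea concrete, here is a minimal pure-Python sketch under stated assumptions: `Stream`, `Event`, `push_stream_dep`, `is_visible`, and `to_dlpack_on_stream` are hypothetical stand-ins, not DLPack or CUDA APIs. It models CUDA-style ordering, where recording an event on one stream and making another stream wait on it (in CUDA, `cudaEventRecord` followed by `cudaStreamWaitEvent`) orders the work without blocking the host.

```python
from dataclasses import dataclass, field


@dataclass
class Event:
    stream: "Stream"
    position: int  # number of ops queued on `stream` when the event was recorded


@dataclass
class Stream:
    handle: int
    ops: list = field(default_factory=list)
    waits: list = field(default_factory=list)  # (ops queued on self at wait time, Event)

    def enqueue(self, name):
        """Queue an op and return its 1-based position on this stream."""
        self.ops.append(name)
        return len(self.ops)

    def record_event(self):
        return Event(self, len(self.ops))

    def wait_event(self, event):
        # Ops queued on this stream from now on are ordered after `event`.
        self.waits.append((len(self.ops), event))


def push_stream_dep(producer, consumer):
    """Hypothetical PushStreamDep: order `consumer` after the work queued so
    far on `producer`, without blocking the host."""
    if producer is not consumer:
        consumer.wait_event(producer.record_event())


def is_visible(producer, write_pos, consumer, read_pos):
    """True if the op at `read_pos` on `consumer` is ordered after the op at
    `write_pos` on `producer` (direct dependencies only)."""
    if producer is consumer:
        return read_pos > write_pos
    return any(
        wait_pos < read_pos and ev.stream is producer and ev.position >= write_pos
        for wait_pos, ev in consumer.waits
    )


def to_dlpack_on_stream(array, consumer):
    """Hypothetical producer API: make `array` visible on `consumer`."""
    push_stream_dep(array["stream"], consumer)
    return {"data": array["data"]}  # stand-in for a DLPack capsule
```

A read queued on the consumer stream before the call is not ordered after the producer's write; a read queued after it is. Nothing in the sketch waits on the host, which is the point tqchen makes about producer-side handling.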

To summarize, I think it is good to have a discussion about a synchronization API. However, given the current state, I think it would be useful to decouple the discussion of synchronization (which is harder to get consensus on) from the data structure (which most frameworks already agree on). I would also argue that simplicity is not necessarily a bad thing here, since we could build APIs that are just as powerful while driving adoption in common cases.

tqchen commented 3 years ago

To summarize things further, let us discuss the following two variants of possible APIs to support synchronization:

From the flexibility point of view, an S1-style API is certainly more flexible. For example, the consumer can choose to continue running computation on the streams provided by the producer.

The reality, however, is that most application/framework developers (rightfully) have opinions about which streams to run computation on (those attached to their own internal scheduler). So the best thing the consumer could do is to run a dependency sync to bring the array to its internal stream and then consume the data, if the developer knows how to do so. In many cases developers may not even want to think about synchronization at all and operate on the default stream; we certainly want to be able to support these developers when possible. S1 also brings the burden of handling complications such as two operands that sit on different streams.

For an S0-style API, simplicity is itself a merit: more developers can agree on and correctly implement such a protocol. While limited, it won't slow down computation, because it can still support asynchronous computation when possible. It also separates the opinion about the data structure from the opinion about synchronization.

leofang commented 3 years ago

Hi @veritas9872 @tqchen Thanks for your replies! Looks like @tqchen you have put serious thoughts on this, and it's great to have you write it up here for future reference 🙂

I would like to kindly nudge you to read the Synchronization section of the CUDA Array Interface (CAI) v3, as many factors you raised are already thoroughly considered and covered there. An over-brief summary is below:

All these considerations together guarantee the correctness of computations, which was deemed critical across the libraries at the time (PyTorch, TF, Numba, CuPy, RAPIDS, etc), while leaving flexibilities to expert users, in particular advanced HPC/ML users. In the CAI documentation a very complicated example was given to demonstrate that without such a synchronization requirement a computation would almost certainly go wrong.

My conjecture for why CAI (before v3) and DLPack have worked well so far is that most libraries live on CUDA's legacy default stream, which implicitly synchronizes/serializes the kernels launched on it. But as more libraries adopt the per-thread default stream, this will eventually become a real concern.

If DLPack is going to assume the Producer and Consumer need to live on the same stream, it is fine to me, but it should be clearly stated in the documentation, which is currently not the case. Also, we would need a way to make sure of this if correctness is of top concern, which circles back to the 2nd point above that we need to access the Producer's stream pointer.

The main problem is not about a solution for synchronization, but to get frameworks to agree on the mechanism.

I totally agree. We had a very hard time when revising v3. Library and framework maintainers have their own opinions.

But it's very odd to me that you mentioned TF, because TF was one of the main drivers that were unwilling to adopt CAI unless we clearly specified the synchronization behavior (hence the v3 revision). So it is difficult for me to understand how they are OK with DLPack's (currently) unsafe behavior but gave us tons of headaches (and in the end it's still unclear to me whether they'd implement CAI v3).

On the API consideration, I think one important distinction between CAI and DLPack is that we do not expose the intermediate object to the User, as all of the information is simply dumped into a dict and generated (consumed) internally by the Producer (Consumer). The benefit is threefold:

So I think this is the reason that you have to consider the S0 vs S1 styles above. To us it's really not an issue if the struct could have an additional field to hold the stream pointer (following S1), and Users can simply write (pseudo) code as simple as this (see the CAI doc for a better, practical example)

a = producer.array([1, 2, 3])
b = consumer.asarray(a)  # <-- sync happens here, if a stream pointer is given

that is, there are no explicit calls to to_CAI() and from_CAI(); the struct lives in the dict a.__cuda_array_interface__. All the burden of syncing is on the Producer/Consumer maintainers, but it's really minimal. For example, CuPy's Producer code is here, and its Consumer code is here: just a few lines. In fact I've seen shorter ones, so to me it should not be an obstacle (disclaimer: I maintain this code). (In comparison, CuPy's to_dlpack() is here and from_dlpack() is here, with many more lines.)
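A minimal sketch of that CAI-style flow, assuming a toy producer and consumer: `ProducerArray`, `FakeStream`, and `consumer_asarray` are hypothetical names, and the pointer is fake. The dictionary keys `shape`, `typestr`, `data`, `version`, and `stream` follow the CAI v3 layout; the consumer does the stream ordering inside its `asarray`-like function, so the user never handles a stream.

```python
class FakeStream:
    """Toy stand-in for a consumer-side CUDA stream that logs ordering calls."""

    def __init__(self, handle):
        self.handle = handle
        self.log = []

    def wait_on(self, producer_handle):
        # A real consumer would record an event on the producer's stream and
        # call cudaStreamWaitEvent on its own stream; here we only log it.
        self.log.append(("wait_on", producer_handle))


class ProducerArray:
    """Toy producer exposing a CAI v3-style dictionary."""

    def __init__(self, values, stream_handle):
        self.values = list(values)
        self.stream_handle = stream_handle  # None means no sync is needed

    @property
    def __cuda_array_interface__(self):
        return {
            "shape": (len(self.values),),
            "typestr": "<i8",
            "data": (0x7F0000000000, False),  # fake (pointer, read_only) pair
            "version": 3,
            "stream": self.stream_handle,
        }


def consumer_asarray(obj, consumer_stream):
    """Hypothetical consumer: stream-order at the point of consumption."""
    cai = obj.__cuda_array_interface__
    producer_stream = cai.get("stream")
    if producer_stream is not None and producer_stream != consumer_stream.handle:
        consumer_stream.wait_on(producer_stream)
    return {"shape": cai["shape"], "data": cai["data"]}
```

From the user's side this stays `b = consumer.asarray(a)`; the stream handle rides along inside the dict, and the ordering happens when the data is actually consumed.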

It is necessary to revise DLPack's struct. Suppose we add an optional stream argument to to_dlpack() (but not to from_dlpack(), since all that matters is the Producer's stream), following more or less your S0 style:

def to_dlpack(stream=None):
    pass
def from_dlpack(object):
    pass

this has three disadvantages:

  1. it puts the burden of managing streams on the Users, which I don't appreciate
  2. it's not necessarily possible for the Producer to expose its internal stream handles (as in TF's case), so Users just cannot access them
  3. the synchronization would happen when the DLPack intermediate object is generated, not when it is actually consumed, because the stream pointer is not carried by the object

Given the above reasons, I just don't think your S0 style would work if stream synchronization is considered an issue.

rgommers commented 3 years ago

On the API consideration, I think one important distinction between CAI and DLPack is that we do not expose the intermediate object to the User

A note just on this: we discussed that recently and agree to change the Python API to from_dlpack(x) and have x.__dlpack__ in order not to have a user-level capsule object: https://github.com/data-apis/consortium-feedback/issues/1#issuecomment-726111658.

szha commented 3 years ago

Taking a step back from CUDA's streams, to support synchronization I think DLPack may need to support other modes of scheduling as well (e.g. barriers). It might make sense to start with a more limited scenario first (i.e. always synchronize) that can be implemented in all abstractions, and then add more interfaces to relax it.

tqchen commented 3 years ago

Right now the options are clearly summarized as per S0 and S1 design choices. The choice of python API has converged on what @rgommers described.

The synchronization perspective of CAI mentioned by @leofang is summarized in S1 (I tried to summarize it after reading the CAI v3 proposal). Right now, DLPack adopts simple synchronization semantics (per S0), specified as:

To keep things simple, producer and consumer are required to operate on the same stream,
otherwise an explicit synchronization is needed to make the data visible to the right stream.

It would be great to try to talk about S0 and S1 style API further, perhaps from the framework implementer's PoV.

tqchen commented 3 years ago

Here is my (personal) take on an S1-style API (e.g. CAI v3). First of all, it is certainly possible to specify the runtime-dependent synchronization behavior correctly in an S1 API (CAI v3 should be such an example).

Notably, such a specification also brings additional complications to the implementer that consumes a DLPack tensor, specifically in the following scenarios:

Both K0 and K1 are possible barriers for a fully compatible implementation.

In contrast, S0 style API, while being limited, is simpler to implement overall. We are more likely going to get fully compatible implementations from frameworks.

So our question here is not really about which one is more advanced and flexible (we all agree that S1 is more flexible and "advanced", while S0 is simpler and limits the possible ways to synchronize), but about the tradeoff between simplicity and flexibility.

szha commented 3 years ago

The reason I don't think either S0 or S1 reaches the heart of the problem is that they differ only in where to store the reference needed for synchronization, which we could coordinate as part of defining the synchronization semantics in DLPack. To me, the key difference between the current DLPack and a potential synchronization extension is that in an asynchronous setting the provider needs to return a promise object. Such a promise object can be a reference that frameworks on both the supplying and consuming ends know how to synchronize on, such as a stream number, but that's not always true. It's entirely possible that the provider supplies a promise whose synchronization depends on a mechanism the consumer has no knowledge of. More generally, the synchronization methods need not be coupled with the memory. For example, suppose two imaginary frameworks, TF and MX, run on a platform with both a TPU and a GPU, and TF passes MX an array promise whose computation first happens on the TPU. It's entirely possible that TF is the only framework that knows how to implement synchronization on the TPU. Requiring frameworks to know how to synchronize may not even be possible, which means that DLPack may have to choose not to support such hardware even if the memory space is addressable by all frameworks.

In general, I think it makes sense for the promise object to supply a query function on whether the data is ready or not. Beyond that, I think of CUDA stream, or other synchronization methods commonly supported on multiple frameworks, more of special cases.
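One way to picture the promise-style exchange is the pure-Python sketch below. `ArrayPromise`, `ready()`, `sync()`, and `produce_async` are hypothetical names chosen for illustration, and a worker thread stands in for a device that fills the buffer asynchronously. The consumer only needs the query and wait hooks, not any knowledge of how the producer's device synchronizes.

```python
import threading


class ArrayPromise:
    """Hypothetical promise: memory that a producer fills asynchronously."""

    def __init__(self, size):
        self.buffer = [0] * size
        self._done = threading.Event()

    def ready(self):
        """Non-blocking query: has the producer finished writing?"""
        return self._done.is_set()

    def sync(self, timeout=None):
        """Block until the producer has finished writing."""
        if not self._done.wait(timeout):
            raise TimeoutError("producer did not finish in time")


def produce_async(size, start_gate):
    """Producer side: return a promise immediately and fill it on a worker
    thread that stands in for a device the consumer knows nothing about."""
    promise = ArrayPromise(size)

    def work():
        start_gate.wait()  # lets a caller control when the "device" runs
        for i in range(size):
            promise.buffer[i] = i * i
        promise._done.set()  # only the producer marks completion

    threading.Thread(target=work, daemon=True).start()
    return promise
```

A consumer can poll `ready()` to overlap its own work, or call `sync()` when it actually needs the data; neither call reveals whether the producer used a CUDA stream, a TPU queue, or something else.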

leofang commented 3 years ago

Wait...this brings up a question I've been wondering: What's DLPack's requirement for cross architecture transfer (say cpu <-> gpu)? I thought for the purpose of enabling zero copy with no surprise, it should error out and such transfers must be done by other means (usually a sane library would implement it). Is such behavior documented somewhere?

szha commented 3 years ago

I believe DLPack only does one job at the moment, which is to facilitate the exchange of data in commonly addressable memory spaces. It does not yet handle data transfer across devices. What I described above is the scenario where the final memory address is being filled asynchronously as part of the final step of computation on a special device. ~I think this will come up even in a future version of CUDA when the async mem copy is supported.~ This is now already supported in CUDA 11.1.

leofang commented 3 years ago

Yeah I think cudaMallocAsync etc appear in CUDA 11.2.

leofang commented 3 years ago

The scenario is a more complex version of what was considered in CAI v3, so in this regard I agree with you that CAI v3 is a special case, but it was meant for CUDA only 😗 I brought up CAI v3 mainly to emphasize and echo the original issue: the synchronization (or lack of it) behavior at the interoperation points should be considered, determined, and documented.

szha commented 3 years ago

Yes, for practical reasons I think it's valid (and most likely desired), and I'm happy to continue discussion on coordinating CUDA synchronization as a special case. My above comments are meant to make it clear that the same abstraction doesn't necessarily apply generally.

tqchen commented 3 years ago

Great discussions. Personally, I do not have an opinion on the style of possible synchronization. Memory copies and other operations (e.g. allocations) can all have a stream-based abstraction, so the S0 specification is still sound (although limited).

One observation from the current discussion is that while most of us agree on data structures, there are many opinions about a possible synchronization API (e.g. explicitly specified stream, stream-based, or opaque future/promise). Each of them brings some level of complexity, with its own pros and cons.

My main point is to ideally de-couple the standardization of synchronization (which is more complicated) from standardization of data structure itself.

This would, however, mean that we may not want to couple synchronization-related data structures (e.g. streams) into DLTensor, and would limit synchronization to an S0-level requirement at the DLPack level. I would also argue that the S0-level semantics are simple and self-contained for now, and can handle async settings (via the default stream or the same stream).

Additional discussions about standardizing a synchronization API are great; we could open a separate thread for that :)

tqchen commented 3 years ago

Let me attempt to summarize the proposals so far and their pros and cons

S0: Producer Sync based on Stream Specified by Consumer

def __dlpack__(self, stream=None):
    """Get a DLTensor capsule that is visible in stream.

    The producer is responsible for making its pending work a dependency of stream.
    If stream is None, it defaults to the default stream (0 in CUDA).
    """
    pass
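A toy version of the S0 contract above, assuming hypothetical `ToyStream`, `ToyRuntime`, and `ToyArray` classes and the docstring's convention that `None` maps to the default stream. The consumer passes only an integer handle; the producer looks it up in the process-wide stream space and orders that stream after its own work, so no stream object or lifetime crosses the boundary.

```python
DEFAULT_STREAM = 0  # per the docstring above: None means the default stream


class ToyStream:
    def __init__(self, handle):
        self.handle = handle
        self.waited_on = []  # handles of producer streams this stream was ordered after

    def wait_for(self, other):
        # In CUDA: cudaEventRecord(ev, other) then cudaStreamWaitEvent(self, ev, 0).
        self.waited_on.append(other.handle)


class ToyRuntime:
    """Maps integer stream handles to live stream objects in this process."""

    def __init__(self):
        self.streams = {DEFAULT_STREAM: ToyStream(DEFAULT_STREAM)}

    def get(self, handle):
        if handle not in self.streams:
            self.streams[handle] = ToyStream(handle)
        return self.streams[handle]


class ToyArray:
    def __init__(self, data, stream, runtime):
        self.data = data
        self.stream = stream
        self.runtime = runtime

    def __dlpack__(self, stream=None):
        """Return a capsule stand-in whose data is visible on `stream`."""
        handle = DEFAULT_STREAM if stream is None else stream
        consumer = self.runtime.get(handle)
        if consumer is not self.stream:  # same stream: nothing to do
            consumer.wait_for(self.stream)
        return ("dltensor", self.data)
```

When the consumer asks for the producer's own stream, no dependency is pushed at all; otherwise a single event-style dependency is queued and the host never blocks.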

Pros:

Cons:

A further simplification of S0 would be to require the exchange to always be synced to the default stream (note this is still different from explicit synchronization); this would remove the stream argument from the __dlpack__ function.

S1: Producer Stores the stream in the Exchange Data Structure, Defer Sync to Consumer

Mostly the API style described by @leofang: CAI-style synchronization.

class DLManagedTensor:
    stream: int

def __dlpack__(self):
    """Return an exchange structure that stores the stream."""
    pass

Pros:

Cons:

S2: Producer Stores an opaque Promise object in the Exchange Data Structure, Defer Sync to Consumer

Proposal by @szha

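from typing import Callable

class DLManagedTensor:
    sync: Callable[[], None]   # blocks until the data is ready
    ready: Callable[[], bool]  # non-blocking query: is the data ready?

def __dlpack__(self):
    """Return an exchange structure that stores the promise (sync/ready)."""
    pass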

Pros:

Cons:

tqchen commented 3 years ago

My current proposal is to start simple and adopt an S0-style API. The simplicity and benefits brought by A1 should be sufficient for most high-performance cases.

In reality, I believe the default stream is the only one that different frameworks could agree on, given that frameworks have their own opinions about internal stream contexts. This means the power of B0 in S1 is rarely used.

Additionally, A2 is critical for framework adoption, as we want simple support to be implemented correctly for most cases, while additional effort can be put into bringing in stronger support if necessary.

rgommers commented 3 years ago

S0 with def __dlpack__(self, stream=None): sounds good to me - simplicity and correctness while remaining performant for most cases seems like a good trade-off. A couple of questions:

tqchen commented 3 years ago

@rgommers it applies to driver APIs that have a global stream space that can be converted to integers, so for now that is the CUDA and ROCm case. In such cases, streams are guaranteed to mean the same thing within the same process. For other driver APIs, the producer and consumer need to agree on a convention for the "default stream".

The lack of a global default is a limitation of other APIs, considering the need for exchange. In the single-app scenario, of course, being able to create app-specific contexts was considered flexible. However, such flexibility limits the sharing of contexts between apps that standardization would bring. The producer and consumer need to agree on a default context in this case. A possible later extension could be to create such a global context table for other device APIs that every application can access.

'None' (instead of '0') indicates the default stream, so it is unambiguous. But I believe '0' is also fine.

You are right that from_dlpack does not need a stream argument, so the user does not need to think about it.

kkraus14 commented 3 years ago

If stream isn't contained in the dlpack structure (and assumedly not cleaned up in the deleter function) and we're passing the stream via an integer, who is responsible for maintaining the lifetime of that stream? I.E.

import my_array_lib

my_arr = my_array_lib.create_array_on_stream([1, 2, 3], stream=my_array_lib.create_stream())
# my_arr internally holds a reference to the passed in `stream` object for lifetime management

my_dlpack = my_arr.to_dlpack(stream=int(my_arr.stream))
my_arr = None
# What state are we in now?

From what I can tell, the only way to guarantee that the stream isn't prematurely destroyed here is to either synchronize it before creating some_other_arr, or hold a reference to some_other_arr. If the handoff was to a C library then there's no clean way for it to hold that reference.

dlpack already has lifetime management semantics for the memory in the form of the deleter, where I'd argue we need similar for any stream resources used. We could say that synchronization needs to happen by producers before handing off the dlpack_tensor object, but then this introduces a lot of unnecessary synchronization and makes dlpack arguably less effective in acting as a good interchange mechanism between libraries.
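A minimal sketch of that lifetime concern, assuming hypothetical `Stream`, `Array`, `ManagedTensor`, and `export` names that are only loosely modeled on DLPack's `DLManagedTensor` (whose `deleter` hands memory ownership back to the producer). If the exported struct also holds a reference to the producer's stream and drops it in the deleter, discarding the producer array no longer frees the stream early. The check relies on CPython reference counting, which frees objects as soon as the last reference goes away.

```python
import weakref


class Stream:
    """Toy stream object whose lifetime we want to track."""


class Array:
    def __init__(self, data, stream):
        self.data = data
        self.stream = stream  # the array keeps its stream alive


class ManagedTensor:
    """Hypothetical exchange struct that also owns a stream reference."""

    def __init__(self, data, stream):
        self.data = data
        self.stream = stream

    def deleter(self):
        # Release both the memory and the stream reference, extending the
        # deleter's existing memory-ownership role to the stream.
        self.data = None
        self.stream = None


def export(arr):
    return ManagedTensor(arr.data, arr.stream)
```

With `my_arr = None`, the stream survives as long as the exported struct does, and it is released only when the consumer calls `deleter()`; a C consumer would get the same guarantee without having to hold a Python reference.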

rgommers commented 3 years ago

We could say that synchronization needs to happen by producers before handing off the dlpack_tensor object, but then this introduces a lot of unnecessary synchronization and makes dlpack arguably less effective in acting as a good interchange mechanism between libraries.

Isn't this just a way of saying "I prefer S1 over S0", or " I want optimal performance even when both libraries use non-default streams, at the cost of implementation complexity"?

The trade-off here has some unstated assumption on how often this kind of data interchange is used. If one is gluing together two ML models by exchanging some final or intermediate output once in a while, and the occasional extra sync if both of those models use non-default streams isn't a big deal. If you have use cases that use from_dlpack a lot, that trade-off of implementation complexity vs. performance may be different.

rgommers commented 3 years ago

As a meta-comment: I do not think it's necessarily the right goal to end up with essentially a copy of __cuda_array_interface__ on the C side of DLPack. Given that __cuda_array_interface__ gets away with passing a Python int, it seems like passing a C pointer and managing its lifetime for DLPack should not be necessary (given the CuPy/RAPIDS needs at least).

Other options include:

kkraus14 commented 3 years ago

The trade-off here has some unstated assumption on how often this kind of data interchange is used. If one is gluing together two ML models by exchanging some final or intermediate output once in a while, and the occasional extra sync if both of those models use non-default streams isn't a big deal.

This isn't true in many cases. Take for example a library which uses the default stream asynchronously, which is not uncommon. If I'm forced to synchronize the stream to make the memory able to be used immediately by any other stream, I'm required to synchronize the entire device which prevents me from doing any other work on the device until that synchronization is finished. I'd argue this is a big deal.

As a meta-comment: I do not think it's necessarily the right goal to end up with essentially a copy of __cuda_array_interface__ on the C side of DLPack.

I 100% agree that it shouldn't be a copy, but if we're interested in being an interchange protocol for devices that have an asynchronous execution model and an asynchronous memory model, then we should really support that asynchrony in the interchange protocol.

Given that __cuda_array_interface__ gets away with passing a Python int, it seems like passing a C pointer and managing its lifetime for DLPack should not be necessary (given the CuPy/RAPIDS needs at least).

I personally don't like that we're passing things like memory pointers / streams around as integers in Python when we could be passing some type of object oriented wrapper around that handles the lifetime management instead, but we followed the spec for __array_interface__ to start and then followed suit in adding the stream parameter.

tqchen commented 3 years ago

We are getting back to the argument of S1 style API and S0 style API :)

That said, it would be great to dissect the argument a bit. Right now the discussion goes back to full synchronization vs. async handling, and implies that "S0 == sync entire device, S1 == async handling". This is not necessarily true.

I want to highlight that S0 does not imply synchronization of the entire device, and is never meant to say so. The decision of synchronization vs async handling can be done by the implementer in either S0 or S1 settings.

Let us discuss how async exchange can be handled in S0:

Note that neither W0 nor W1 requires synchronizing the entire device, and both fall under the S0 API. The only difference from S1 is that PushStreamDep is called in __dlpack__ (when both the producer and consumer streams are alive). This also avoids the stream lifecycle management problem @kkraus14 mentioned. As we know, common frameworks want to manage the lifecycle of their own streams, so they are not readily equipped to handle streams exported from another framework.

In both W0 and W1, the producer and consumer streams are managed only by the producer and consumer respectively, without having to worry about stream lifecycles due to exportation. This is another simplicity brought by S0-style handling.

To summarize, I think we all agree that it is ideal to handle exchange asynchronously when possible. Both S0 and S1 style API have mechanism to do so.

kkraus14 commented 3 years ago

When both sides uses default stream, then no sync is needed (this is the simplest setting)

Another case we covered in __cuda_array_interface__ is if the data is already synchronized and therefore can be consumed by any stream without doing any unnecessary synchronization or event waiting. I.E. someone produced something on a stream, but is continuing to do additional work on that stream, we don't want to unnecessarily synchronize the stream or wait on an event in the stream (unless the event could be guaranteed to be inserted onto the stream in the correct place relevant to the memory in question, but that sounds unlikely) and be blocked by unrelated work.

We decided to handle this using None for the case where the data does not require any synchronization, and explicitly disallowed 0, since the default stream is ambiguous between the legacy default stream (value 1) and the per-thread default stream (value 2).
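The stream-value convention described above fits in a small helper. `interpret_cai_stream` is a hypothetical name; treating any other integer as a raw `cudaStream_t` follows the CAI v3 documentation.

```python
LEGACY_DEFAULT_STREAM = 1
PER_THREAD_DEFAULT_STREAM = 2


def interpret_cai_stream(value):
    """Classify the `stream` entry of a __cuda_array_interface__ v3 dict."""
    if value is None:
        return "no synchronization required"
    if isinstance(value, bool) or not isinstance(value, int):
        raise TypeError("stream must be None or an int")
    if value == 0:
        # Ambiguous between "no stream" and the two default streams.
        raise ValueError("stream=0 is disallowed")
    if value == LEGACY_DEFAULT_STREAM:
        return "legacy default stream"
    if value == PER_THREAD_DEFAULT_STREAM:
        return "per-thread default stream"
    return "cudaStream_t handle 0x%x" % value
```

Rejecting 0 outright, rather than guessing which default stream was meant, forces producers to say explicitly whether synchronization is needed at all.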

W0: in function dlpack(self, consumer_stream), the Producer calls PushStreamDep(producer_stream, consumer_stream)

What would C/C++ libraries do? How do a producer C/C++ library and a consumer C/C++ library exchange stream information in a standard way? CUDA streams could likely be an implementation detail of the library where they're not exposed to a user developing with the library.

If we want to avoid putting streams, events, etc. into the dlpack_tensor struct then that's perfectly fine, but it really feels like standardizing how the information needs to be exchanged should be part of whatever protocol/standard is defined.

Producer calls PushStreamDep(producer_stream, default_stream), so the data will be visible from the default stream

Unfortunately, calling cudaStreamWaitEvent on the default stream will synchronize the entire device similar to if you used cudaStreamSynchronize (https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html#stream-sync-behavior__default-stream), so this would always synchronize the entire device.

tqchen commented 3 years ago

Thanks @kkraus14 , great points :) trying to answer inline

I.E. someone produced something on a stream, but is continuing to do additional work on that stream, we don't want to unnecessarily synchronize the stream or wait on an event in the stream (unless the event could be guaranteed to be inserted onto the stream in the correct place relevant to the memory in question, but that sounds unlikely) and be blocked by unrelated work.

This certainly can be covered in S0 style interface as well. Because when the user requests default, and it is the default, or when the producer and consumer stream matches each other then no synchronization is needed.

What would C/C++ libraries do? How do a producer C/C++ library and a consumer C/C++ library exchange stream information in a standard way? CUDA streams could likely be an implementation detail of the library where they're not exposed to a user developing with the library.

The same problem will be faced in an S1-style API as well :) There are a few ways to handle it. Notably, an S0-style API does not need to handle stream lifecycle management, which simplifies the problem. Being able to have a faithful representation of cudaStream_t (that is alive at the time __dlpack__ is called) is sufficient. For cudaStream_t, it seems fine to store it as an integer, as is common practice in CUDA. The consumer owns consumer_stream and ensures it is alive when calling __dlpack__, so the producer can correctly call PushStreamDep without having to pass streams and their lifetime management around.

If we want to avoid putting streams, events, etc. into the dlpack_tensor struct then that's perfectly fine, but it really feels like standardizing how the information needs to be exchanged should be part of whatever protocol/standard is defined.

I do not disagree on this part. see also the above comments

Unfortunately, calling cudaStreamWaitEvent on the default stream will synchronize the entire device similar to if you used cudaStreamSynchronize (https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html#stream-sync-behavior__default-stream), so this would always synchronize the entire device.

Based on my read of the spec, this does not imply synchronization (which blocks the host until the event completes); instead it creates an implicit dependency of other streams on the sync action. In the case of the legacy default stream this is totally fine (as that was the user's intent). It is like placing a global barrier in an async queue: all future async actions queued by both apps need to wait until we pass the barrier. Of course, when other streams are passed in, there will be less blocking.

tqchen commented 3 years ago

Thanks for the great discussion so far. One generic point brought up is how to pass a pointer object. In the S0-style API, the stream is passed as an integer while dl_tensor is passed as a capsule.

This really depends on the semantics of the pointer.

Notably, while it is quite common to exchange ownership of memory (so others can access it later), it is much less common for frameworks to exchange streams, since most frameworks have opinions about their own internal async scheduling mechanisms built on top of streams. The additional consideration of stream lifecycle management is also an engineering burden.

tqchen commented 3 years ago

It seems that we are converging; I would be happy to discuss more. Thanks @kkraus14 @leofang @szha for the great discussions.

leofang commented 3 years ago

Sorry, I still have concerns, @tqchen. I will try to catch up on the discussion of the past few days asap. Thanks!!!

tqchen commented 3 years ago

Thanks @leofang. To help us reach consensus, it would also be great to dissect the discussion.

In most cases we may not converge on Q1, since engineering is a tradeoff and different people might prefer different points on the spectrum. The first thing we want to ensure is convergence on Q0, which means the design perspectives (e.g. synchronization, design choices, stream lifecycle) are covered by the current list of choices and their pros/cons have been discussed.

leofang commented 3 years ago

I want to make a few comments while catching up on the great discussions from the past week, and then summarize my perspective in the next post. So forgive me for thinking out loud, and sorry for my slowness...😅

  1. @kkraus14 had a concern about stream lifetime. I want to point out that CAI v3 specifies that it is the Producer's responsibility to keep the stream alive during the exchange. Since the synchronization, if not opted out of, is done right at the exchange point, the stream is allowed to be a transient object and still meets the need. The same applies to @tqchen's S0 style IIUC. So lifetime management shouldn't be an issue in either case. It's only an issue if we let a PyCapsule object float around between the calls to toDlpack() and fromDlpack(), but IIUC we are ruling out this possibility.
  2. I too disagree that DLPack has to converge to CAI v3 (i.e. an S1 style protocol); that was never my intention. But we did put a lot of thought into revising it so as to ensure it's legit, solid, and yet performant if needed. The point I've been trying to make is simply that these considerations should be appreciated and incorporated in the Array API.
  3. When revising CAI v3 there was actually a suggestion to make it a callable, following more or less the promise approach (labelled S2) suggested by @szha. The main driving factor was that accessing attributes/properties in Python can have side effects: for example, if we do a sync or raise an error when accessing __cuda_array_interface__, the same action will be done simply by calling hasattr(x, '__cuda_array_interface__') (the same applies to getattr)!!! This is very bad, so in hindsight S2 is not unreasonable to me, and I want to make sure people are aware of this Python limitation when considering, say, __dlpack__. (We are currently dealing with exactly this bug in CuPy: https://github.com/cupy/cupy/pull/4524.) In fact, S2 can be made equivalent to S1 if such a callable simply wraps cudaStreamSynchronize() or cudaStreamWaitEvent().
  4. I disagree, though, with @szha's statement that "the synchronization methods need not couple with the memory" (https://github.com/dmlc/dlpack/issues/57#issuecomment-752817174). Assuming we exclude cross-architecture transfers (which necessarily involve copies), the whole point of synchronizing (or not) is to ensure the memory is ready to be consumed.
  5. stream = None, 0, 1, 2 have their own meanings in CAI v3/CUDA; please refer to my comment in the other PR: https://github.com/data-apis/array-api/pull/106#discussion_r550960683. It's not an arbitrary choice. And yes, streams in CUDA/HIP are unique within the same process.
  6. I didn't appreciate that S0 asks the Producer to sync on the Consumer's stream, so thanks @tqchen for the nice summary in https://github.com/dmlc/dlpack/issues/57#issuecomment-753220425. As I noted earlier, one of my concerns was being unable to access the Producer's internal streams. But IIUC such access is not needed if the Producer is going to sync over the Consumer's stream, as it can make all of its internal streams (of which it has full knowledge) wait on the provided stream.
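
The CAI v3 meanings referred to in item 5 can be summarized as a small decoder. This is a sketch based on the CAI v3 `stream` semantics (None: no synchronization required; 0: disallowed; 1: legacy default stream; 2: per-thread default stream; any other integer: a `cudaStream_t`); `describe_cai_v3_stream` is a hypothetical helper, not part of CuPy, Numba, or DLPack.

```python
from typing import Optional


def describe_cai_v3_stream(stream: Optional[int]) -> str:
    """Interpret a CAI v3 `stream` value (hypothetical helper)."""
    if stream is None:
        return "no synchronization required"
    if isinstance(stream, bool) or not isinstance(stream, int):
        raise TypeError("stream must be None or an int")
    if stream == 0:
        # Ambiguous between None and the default stream, and between the
        # legacy and per-thread default streams, so CAI v3 disallows it.
        raise ValueError("stream=0 is disallowed")
    if stream == 1:
        return "legacy default stream"
    if stream == 2:
        return "per-thread default stream"
    return f"cudaStream_t handle 0x{stream:x}"


for value in (None, 1, 2, 0x7F3A2C00):
    print(value, "->", describe_cai_v3_stream(value))
```
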
szha commented 3 years ago

@leofang thanks for sharing. On 4., in my original comment I meant to tease apart the memory used for the exchange from the synchronization mechanism internal to the producer, which is responsible for producing the result in that memory space. This applies when the consumer has access to the commonly addressable memory space but not to the producer's internal synchronization mechanism.

Apparently I should have done a better job at describing it. I hope that the clarification makes more sense 😆

rgommers commented 3 years ago

For example, if we do a sync or raise an error when accessing __cuda_array_interface__, the same action will be done simply by calling hasattr(x, '__cuda_array_interface__') (same applies to getattr)!!! This is very bad, so with hindsight S2 is not unreasonable to me, and I wanna make sure people are aware of this Python limitation when considering, say, __dlpack__.

__dlpack__ is a method rather than an attribute, so hasattr should not produce this side effect. The reason for making __array_interface__ an attribute rather than a method is probably lost in the mists of time, but it seems an odd choice to me.
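The difference between the two access patterns can be checked directly in plain Python: `hasattr` evaluates a property's getter, but only looks up a method without calling it. The classes below are illustrative only, and the `syncs` counter stands in for a stream synchronization.

```python
class WithProperty:
    """Interface exposed as a property, like __cuda_array_interface__."""

    def __init__(self):
        self.syncs = 0

    @property
    def __cuda_array_interface__(self):
        self.syncs += 1  # stands in for a synchronization done on access
        return {"version": 3}


class WithMethod:
    """Interface exposed as a method, like __dlpack__."""

    def __init__(self):
        self.syncs = 0

    def __dlpack__(self, stream=None):
        self.syncs += 1  # only runs when the method is actually called
        return object()


p = WithProperty()
hasattr(p, "__cuda_array_interface__")  # runs the getter
print(p.syncs)  # 1

m = WithMethod()
hasattr(m, "__dlpack__")  # looks up the bound method, does not call it
print(m.syncs)  # 0
```
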

harrism commented 3 years ago

This certainly can be covered in S0 style interface as well. Because when the user requests default, and it is the default, or when the producer and consumer stream matches each other then no synchronization is needed.

I want to share a subtle fact about CUDA streams that affected us recently in the RAPIDS Memory Manager (RMM), specifically related to "when the producer and consumer stream matches" above. The fact is that you cannot always rely on A == B meaning that stream A and stream B are actually the same stream. The reason is that cudaStream_t is just a pointer, and like any pointer, it can be reused. So if the producer (to_dlpack) stores the stream A passed to it, then time passes and the owner of stream A calls cudaStreamDestroy(A) and later cudaStreamCreate(&B) before calling the consumer (from_dlpack) with stream B, comparing the stored stream A to the stream B passed to the consumer may result in equality, even though they are not the same stream.

Therefore, while this is true:

"Because when the user requests default, and it is the default, ~or when the producer and consumer stream matches each other~ then no synchronization is needed."

This is not always true:

"Because ~when the user requests default, and it is the default, or~ when the producer and consumer stream matches each other then no synchronization is needed."

(Forgive awkward reuse above, I'm trying to make the importance obvious in context.)

In general it is dangerous to store a stream unless you control its lifetime, but the above is much more subtle than just the typical "use after free". Storing events recorded on a stream is safer, but slightly more expensive.
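The pitfall can be reproduced with a toy allocator. `HandleAllocator` is hypothetical and only mimics the possibility described above (a freed handle value being handed out again); whether and when the CUDA runtime actually reuses a handle is not something to rely on either way.

```python
class HandleAllocator:
    """Toy stand-in for stream creation: freed handle values may be reused."""

    def __init__(self):
        self._free = []  # handles released by destroy(), reused LIFO
        self._next = 0x1000

    def create(self):
        if self._free:
            return self._free.pop()
        handle = self._next
        self._next += 0x10
        return handle

    def destroy(self, handle):
        self._free.append(handle)


alloc = HandleAllocator()
a = alloc.create()
stored = a          # producer stores the raw handle for later comparison
alloc.destroy(a)    # owner of stream A destroys it
b = alloc.create()  # a new, unrelated stream B gets the same value
print(stored == b)  # True, even though B is not A
```

Comparing raw handles is therefore only meaningful while both streams are known to be alive.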

tqchen commented 3 years ago

Thanks @harrism for the great point.

In the case of S0, the check of stream equivalence happens during the exchange (when __dlpack__ is called), not after it. So "when the producer and consumer stream matches each other then no synchronization is needed" is still correct, as both streams are guaranteed to be alive when __dlpack__ is called. Of course, we do not guarantee both to be alive after the __dlpack__ call finishes, but the necessary sync happens inside the __dlpack__ function.

In the case of S1, we will need to not only store the stream itself but also transfer (share) its ownership in the DLTensor, because the use of the stream is deferred until consumption time; this avoids the use-after-free problem.
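In Python terms, sharing ownership for deferred (S1-style) use can be as simple as the exchanged object holding a strong reference to the stream object. This is a sketch with hypothetical `Stream` and `DeferredExchange` classes; a real DLTensor-based design would need an equivalent C-level mechanism.

```python
import gc
import weakref


class Stream:
    """Hypothetical owning stream object; its handle would be destroyed with it."""

    def __init__(self, handle):
        self.handle = handle


class DeferredExchange:
    """S1-style exchange: the stream is used later, at consumption time."""

    def __init__(self, data, stream):
        self.data = data
        self._stream = stream  # strong reference = shared ownership of the stream

    def consume(self):
        # Safe: the stream cannot have been collected while we hold it.
        return self.data, self._stream.handle


s = Stream(0x2000)
alive = weakref.ref(s)
ex = DeferredExchange([1, 2], s)
del s                       # the producer drops its own reference
print(alive() is not None)  # True: the exchange object keeps the stream alive
print(ex.consume())         # ([1, 2], 8192)
del ex
gc.collect()
print(alive() is None)      # True: released once the exchange is dropped
```
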

leofang commented 3 years ago

Thanks @harrism for joining the discussion. I agree that lifetime is an important consideration, and that in general S0 can be considered safe with regard to this, but @tqchen let's not jump to the conclusion too fast 🙂

I think S1 can also be considered safe; how to achieve that is an implementation detail. For example, in CuPy there's a concept of the current stream, whose lifetime is guaranteed by internally holding a reference to it, so if an exchange is done through an S1 style API (ex: CAI v3) on the current stream, it is safe. A few other libraries that I am aware of also have more or less similar mechanisms that allow a safe implementation.

For the purpose of reaching a conclusion, it looks like @tqchen has a very strong preference for the S0 style. Am I getting it right that this is the final version of how it's expected to work?

y = consumer.from_dlpack(x_from_producer, consumer_stream)

# internally, from_dlpack calls x_from_producer.__dlpack__(consumer_stream)

If so, I don't have major objection other than pointing out (again) that this leaves the burden of stream management on the users, not on the library implementers.

One minor thing I am picking up from an earlier discussion:

Notably, while it is quite common to exchange ownership of memory (so others can access it later), it is much less common for frameworks to exchange streams.

It is not true. You'd be surprised by the need. For example at least both CuPy and Numba have the concept of external streams.

tqchen commented 3 years ago

Thanks @leofang

If so, I don't have major objection other than pointing out (again) that this leaves the burden of stream management on the users, not on the library implementers.

The user facing API will be

y = consumer.from_dlpack(x_from_producer)

And it is consumer.from_dlpack's job to pass in consumer stream into the x_from_producer.__dlpack__(consumer_stream).
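A minimal sketch of the consumer side of that flow, assuming a consumer library with a notion of a current stream. `Consumer`, `current_stream`, and `RecordingProducer` are hypothetical names used only to show where the stream handle comes from; the user-facing call takes no stream argument.

```python
class RecordingProducer:
    """Hypothetical producer that records which stream the consumer handed it."""

    def __init__(self):
        self.seen_stream = None

    def __dlpack__(self, stream=None):
        self.seen_stream = stream
        return "capsule"  # placeholder for a real PyCapsule


class Consumer:
    """Hypothetical consumer library with a notion of a 'current stream'."""

    def __init__(self, current_stream):
        self.current_stream = current_stream  # integer handle owned by the consumer

    def from_dlpack(self, x):
        # The user never passes a stream: the consumer supplies its own,
        # which it keeps alive for the duration of the __dlpack__ call.
        capsule = x.__dlpack__(stream=self.current_stream)
        return ("consumer_array", capsule)


consumer = Consumer(current_stream=2)  # e.g. 2 = per-thread default stream in CAI v3 terms
x_from_producer = RecordingProducer()
y = consumer.from_dlpack(x_from_producer)
print(x_from_producer.seen_stream)  # 2
```
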

It is not true. You'd be surprised by the need. For example at least both CuPy and Numba have the concept of external streams.

I am speaking from my own experience of building MXNet and TVM, along with looking at other frameworks like PyTorch and TF. There are certainly frameworks that can support external streams, but such complexity grows as a framework starts to introduce its own internal async runtime. So my claim is that it is less common (rather than impossible).

harrism commented 3 years ago

Notably, while it is quite common to exchange ownership of memory (so others can access it later), it is much less common for frameworks to exchange streams.

It is not true. You'd be surprised by the need. For example at least both CuPy and Numba have the concept of external streams.

Related aside: In RMM we have recently introduced stream wrapper classes, and we decided to have both an owning (rmm::cuda_stream) and a non-owning "view" wrapper (rmm::cuda_stream_view). The former is basically an RAII class for creating streams and simplifying lifetime management, while the latter provides us a strong type for stream parameters in our APIs.
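The owning/non-owning split translates naturally to Python. The classes below are hypothetical Python analogues, not RMM's C++ API: `CudaStream` owns a (simulated) stream and releases it when its context manager exits, while `CudaStreamView` only carries the handle and never destroys anything.

```python
import itertools

_live_streams = set()  # simulated driver-side streams
_handle_source = itertools.count(0x3000, 0x10)


class CudaStream:
    """Owning wrapper: creates a (simulated) stream and destroys it on exit."""

    def __init__(self):
        self.handle = next(_handle_source)
        _live_streams.add(self.handle)

    def view(self):
        return CudaStreamView(self.handle)

    def close(self):
        _live_streams.discard(self.handle)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # do not swallow exceptions


class CudaStreamView:
    """Non-owning view: a typed handle that never destroys the stream."""

    __slots__ = ("handle",)

    def __init__(self, handle):
        self.handle = handle

    def is_alive(self):
        return self.handle in _live_streams


with CudaStream() as s:
    v = s.view()
    print(v.is_alive())  # True
print(v.is_alive())      # False: the view outlived its owner
```

As in the C++ case, the view is cheap to pass around as a parameter type, but it is the caller's job to make sure the owner outlives every use of it.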

leofang commented 3 years ago

And it is consumer.from_dlpack's job to pass in consumer stream into the x_from_producer.__dlpack__(consumer_stream).

Right, thank you @tqchen. Apparently having not reviewed https://github.com/data-apis/array-api/pull/106 led to my stupid question (though in my defense it's hard to keep track of scattered discussions 😂)

It is not true. You'd be surprised by the need. For example at least both CuPy and Numba have the concept of external streams.

I am speaking from my own experience of building MXNet and TVM, along with looking at other frameworks like PyTorch and TF. There are certainly frameworks that can support external streams, but such complexity grows as a framework starts to introduce its own internal async runtime. So my claim is that it is less common (rather than impossible).

I think for the benefit of the Python Array API, it wouldn't hurt to think of a bigger picture here 🙂 Before ML got popular, there were already tons of people relying on performant libraries to interoperate and do all sorts of scientific computing / data analysis / data processing, etc. Even HPC/HTC people could make use of it in their pipelines. In a sense, things are relatively simple on that side of the spectrum -- users can be pleased by having such a simple wrapper utility at their disposal (ex: cupy.cuda.ExternalStream or rmm::cuda_stream_view as @harrism kindly shared above).

rgommers commented 3 years ago

I think for the benefit of the Python Array API, it wouldn't hurt to think of a bigger picture here

We tend not to innovate there too much, though; it's more synthesis - if most/all libraries have some functionality, then add it. And where there are differences, choose the optimal syntax and semantics. Adding new functionality that half the libraries don't support implies an extra implementation burden and hence lowers the chance of adoption, so we should be very careful when doing that. I'd rather see this added in the future if more libraries support/need it, which can be done because DLPack is versioned.

Also, there are other benefits to the simplicity of only passing an integer stream number. For example, with my NumPy maintainer hat on: I'm going to have to propose adding DLPack to NumPy soon, and because for CPU-only libraries like NumPy it doesn't add extra functionality over the buffer protocol or __array_interface__, there's going to be a bit of a discussion about "yet another protocol". DLPack's simplicity helps. Being able to take dlpack.h as it is (now ~170 LoC) and vendor it is nice; adding more CUDA-specific functionality into it makes the conversation a tiny bit harder again.

leofang commented 3 years ago

Thanks, @rgommers. I agree with most of your comment.

I wasn't advocating that we should add a (non-owning) stream-view wrapper or anything like that to the Array API. It was simply a digression in response to @tqchen's earlier comment ("it is much less common for frameworks to exchange streams"), with a few non-ML counterexamples.

I completely understand the challenges ahead in pushing for agreement and timely adoption -- on which you have done fantastic work 👍 -- but I would also like to see that, in any discussion, differing opinions from non-ML communities can be heard and considered. We are not designing a "Python Machine Learning API", after all, and if, as you said, you're going to push this to NumPy, there could be more concerns, as most NumPy users are not ML people. Including these considerations at an early stage could help eliminate some potential pushback.

Since you brought it up, I'd like to leave a quick note:

..., which can be done because DLPack is versioned.

For now DLPack's version is not reliable. v0.3 is already out, but the header's version is unfortunately still at 0.2: https://github.com/dmlc/dlpack/blob/1b794e7088b754f4b0398d211452de0ab28312b5/include/dlpack/dlpack.h#L16. For this reason, we had to manually pin its git commit in CuPy: https://github.com/cupy/cupy/blob/d69d70d5108ad1241fd4e77f4819f57d286e90ef/cupy/core/dlpack.pyx#L60-L62. It should be a quick fix, though.
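A vendoring project could guard against this kind of drift with a small check. The sketch below assumes the header defines its version as `#define DLPACK_VERSION NNN` (e.g. `020` for v0.2), which appears to be the form used by the header linked above at the time; `header_version` and `EXPECTED` are hypothetical names.

```python
import re
from typing import Optional


def header_version(header_text: str) -> Optional[int]:
    """Return N from a '#define DLPACK_VERSION N' line, or None if absent."""
    match = re.search(r"^\s*#define\s+DLPACK_VERSION\s+(\d+)", header_text, re.MULTILINE)
    return int(match.group(1)) if match else None


vendored = "/*! \\brief The current version of dlpack */\n#define DLPACK_VERSION 020\n"
EXPECTED = 30  # hypothetical: the release we intend to vendor (v0.3)
found = header_version(vendored)
if found != EXPECTED:
    print(f"vendored dlpack.h reports {found}, expected {EXPECTED}")
```

Of course, such a check only helps if the macro is bumped together with the content, which is exactly the issue raised here.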

tqchen commented 3 years ago

@leofang Thanks for pointing out the versioning issue. I have updated the version macro to match the file. Moving forward, we will make sure the version tag matches the content.

leofang commented 3 years ago

Thanks for fixing it quickly, @tqchen!

kkraus14 commented 3 years ago

Based on the discussion thus far, it sounds like, regardless of whether we have an S0 or S1 API, many frameworks won't be able to hand over a CUDA stream whose lifetime is managed outside of the framework. If this is the case and we want to avoid synchronizing the stream in this exchange, then we'd need to use something like an event to expose the stream dependency properly (example detailed here: https://github.com/dmlc/dlpack/issues/57#issuecomment-753696812).

If we want to go the route of specifying a destination stream, is it the responsibility of every framework that uses dlpack to also implement the PushStreamDep function, or would this be functionality added to dlpack that every framework could inherit?

tqchen commented 3 years ago

@kkraus14 I believe a common util might be useful, although it should be simple enough for frameworks to just have their own implementation.

leofang commented 3 years ago

Can someone remind me why we want to sync over the Consumer's stream, not the Producer's? Isn't the whole point of synchronization to wait until the Producer is ready to hand out the data? Or do I remember it wrong?

tqchen commented 3 years ago

@leofang The "sync" (PushDep) happens asynchronously and marks a dependency in the queue, which won't block the producer or the consumer, at least in the context of CUDA.
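The non-blocking nature of PushDep can be illustrated with a toy model in which the host only appends work to per-stream queues and a separate "device" loop executes it later. All names here (`Stream`, `Event`, `push_dep`, `run_device`) are hypothetical; in CUDA the corresponding calls would presumably be event recording plus `cudaStreamWaitEvent`, which likewise return without waiting.

```python
from collections import deque


class Stream:
    """Toy stream: the host appends work, the 'device' executes it later."""

    def __init__(self, name):
        self.name = name
        self.pending = deque()  # queued by the host, not yet executed
        self.executed = 0       # ops the device has completed on this stream


class Event:
    """Captures 'everything queued on `stream` so far' at record time."""

    def __init__(self, stream):
        self.stream = stream
        self.target = stream.executed + len(stream.pending)

    def done(self):
        return self.stream.executed >= self.target


def push_dep(producer, consumer):
    """Host-side call: enqueue a wait marker on the consumer and return at once."""
    consumer.pending.append(("wait", Event(producer)))


def run_device(streams, log):
    """Toy device: advance streams round-robin, honoring wait markers."""
    while any(s.pending for s in streams):
        progressed = False
        for s in streams:
            if not s.pending:
                continue
            kind, payload = s.pending[0]
            if kind == "wait" and not payload.done():
                continue  # this stream stalls; other streams keep running
            s.pending.popleft()
            s.executed += 1
            if kind == "kernel":
                log.append(f"{s.name}:{payload}")
            progressed = True
        if not progressed:
            raise RuntimeError("circular dependency")


prod, cons = Stream("producer"), Stream("consumer")
prod.pending.append(("kernel", "fill_x"))
prod.pending.append(("kernel", "scale_x"))
push_dep(prod, cons)                       # returns immediately
cons.pending.append(("kernel", "read_x"))
print(prod.executed, cons.executed)        # 0 0: nothing has executed yet
log = []
run_device([cons, prod], log)
print(log)  # ['producer:fill_x', 'producer:scale_x', 'consumer:read_x']
```

Even though the consumer stream is scheduled first, `read_x` only runs after both producer kernels, and the host never waited on anything.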

leofang commented 3 years ago

Thanks for the quick reply and for putting "sync" in quotes, @tqchen! But if the Producer does not expose a stream or event handle (in CUDA), how does the Consumer establish such a stream order (the dependency you referred to)?