dmlc / dlpack

common in-memory tensor structure
https://dmlc.github.io/dlpack/latest
Apache License 2.0

Specify synchronization semantics #57

Closed kkraus14 closed 3 years ago

kkraus14 commented 3 years ago

From the CUDA side, nearly all of the DL frameworks and array libraries that support dlpack use CUDA streams and stream-order both their computations and memory allocations. In its current form, dlpack doesn't specify any synchronization semantics, nor does it define a way for a producer-consumer pair to exchange the information needed to keep computations stream-ordered.

I imagine there's a similar problem in other contexts as well (OpenCL, ROCm, etc.) where maybe it's possible to generalize an approach.

cc @oleksandr-pavlyk

tqchen commented 3 years ago

In the current API, the producer implements producer.__dlpack__, and the producer needs to take charge of calling PushDep to make sure the data is visible to the consumer stream.

oleksandr-pavlyk commented 3 years ago

I'm coming to this discussion late, and I have not yet had a chance to internalize all the prior comments.

The way I understand the problem, the producer may still have kernels updating the data referenced by producer.__dlpack__. The suggested way to avoid a race condition, without the producer always waiting for those kernels to finish, is for the consumer library to submit to the same stream as the producer (relying on a stream executing its kernels in order).

If my understanding is correct, this does not generalize to SYCL. For the consumer to ensure that SYCL kernels submitted by the producer have finished updating the memory in the array to be shared, the consumer needs to receive a sycl::event from the producer and use it to declare that kernels the consumer submits to work on the array depend on this event, signaling the SYCL runtime to block submission of these kernels on that event.

tqchen commented 3 years ago

@oleksandr-pavlyk I am not that familiar with SYCL. If SYCL has constructs like a fence (barrier) that can depend on events, an analogy to CUDA's void PushStreamDep(cudaStream_t src, cudaStream_t dst) would be for the producer to submit a fence operation (or a dummy kernel) to consumer_stream (perhaps a context or queue in OpenCL's terms), so that later operations on the same consumer_queue will see the data.
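
For concreteness, here is a toy Python model of what PushStreamDep would do. PushStreamDep is not a real CUDA API, just shorthand for "record an event on `src` and make `dst` wait on it"; all classes and names below are illustrative.

```python
class Event:
    def __init__(self):
        self.done = False

class Stream:
    """An in-order queue of operations, drained in submission order."""
    def __init__(self):
        self.ops = []

    def launch(self, fn):
        self.ops.append(fn)

    def record_event(self):
        ev = Event()
        self.ops.append(lambda: setattr(ev, "done", True))
        return ev

    def wait_event(self, ev):
        def _wait():
            # A real runtime would block on the device; the model just
            # checks the dependency was satisfied before later ops run.
            assert ev.done, "dependency not satisfied"
        self.ops.append(_wait)

    def drain(self):
        for op in self.ops:
            op()
        self.ops.clear()

def push_stream_dep(src, dst):
    """Later work on dst sees everything already enqueued on src."""
    dst.wait_event(src.record_event())

producer, consumer = Stream(), Stream()
data = []
producer.launch(lambda: data.append(42))       # producer writes the tensor
push_stream_dep(producer, consumer)
consumer.launch(lambda: data.append(data[0]))  # consumer reads it
producer.drain()
consumer.drain()
```

Draining the producer first then the consumer stands in for the ordering the event dependency guarantees; swapping the drain order would trip the assertion.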

leofang commented 3 years ago

Ah right, thanks for the reminder @tqchen. PushStreamDep seems to be the last missing piece for my understanding, though we should make it clear that it's a shortcut construct stemming from https://github.com/dmlc/dlpack/issues/57#issuecomment-753696812, not a CUDA API 🙂

@oleksandr-pavlyk I think if SYCL's queues and events are analog to CUDA's streams and events, PushStreamDep should work fine? Can these SYCL handles be passed as integer in Python, or they have to be some opaque objects?

oleksandr-pavlyk commented 3 years ago

sycl::queue can be in-order (like CUDA stream), or out-of-order (default).

It is not possible in SYCL to insert a barrier into an out-of-order queue that would make all further submissions to the queue depend on a given event.

It is possible to do it synchronously from the host with producer_queue.wait().

In view of this, an implementation of PushStreamDep in SYCL is only possible if consumer_queue is in-order, i.e. consumer_queue.is_in_order() returns true.

Synchronizing with an out-of-order queue requires passing an event, which is used when submitting new kernels (e.g. see the USM version of GEMM in oneMKL).

Regarding whether SYCL handles can be passed as integers in Python: I am afraid not (just as shared pointers cannot be passed as integers in Python). The dpctl Python package exposes SyclQueue, SyclEvent, and SyclDevice objects, among others.

The opaque pointers to SYCL objects can of course be passed around as named PyCapsule objects.
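
As an illustration of the named-PyCapsule route, the CPython capsule API can be exercised from pure Python via ctypes; the capsule name "SyclQueueRef" and the wrapped pointer below are made up for the example, and in practice the capsule would be created by C extension code:

```python
import ctypes

PyCapsule_New = ctypes.pythonapi.PyCapsule_New
PyCapsule_New.restype = ctypes.py_object
PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]

PyCapsule_GetPointer = ctypes.pythonapi.PyCapsule_GetPointer
PyCapsule_GetPointer.restype = ctypes.c_void_p
PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p]

NAME = b"SyclQueueRef"       # kept alive for the capsule's lifetime
handle = ctypes.c_int(0)     # stands in for an opaque sycl::queue pointer
addr = ctypes.addressof(handle)

capsule = PyCapsule_New(addr, NAME, None)          # wrap the opaque pointer
roundtripped = PyCapsule_GetPointer(capsule, NAME) # unwrap by matching name
```

The name acts as a lightweight type tag: PyCapsule_GetPointer fails unless the consumer asks for the same name the producer used.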

tqchen commented 3 years ago

Thanks @oleksandr-pavlyk . I agree that in such cases we would require exchanges through in order queues provided by the consumer.

To adapt to such a case when the consumer wants to operate on out-of-order queues, the consumer can first create an in-order queue just for the exchange, create an event X after the exchange finishes, and use X as a dependency for future out-of-order submissions that depend on the data. Likely this won't create additional overhead beyond the event dependency tracking.

This style of adaptation would decouple synchronization and lifecycle management from the data structure.

Speaking from a framework implementer's PoV: the generalization to an arbitrary event dependency chain, although flexible, creates additional overhead during exchange for lifecycle management, synchronization conventions, etc. Having a layered, simpler view (e.g. a default in-order queue) would help SYCL in the long run, and here we can learn from CUDA (simple and good enough for most cases).
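
The adaptation described above can be modeled in a few lines of Python. Everything here (the queue classes, `depends_on`, the event) is an illustrative stand-in, not a real SYCL binding:

```python
class Event:
    def __init__(self):
        self.complete = False

class InOrderQueue:
    """Tasks run strictly in submission order."""
    def __init__(self):
        self.tasks = []

    def submit(self, fn):
        self.tasks.append(fn)

    def record_event(self):
        ev = Event()
        self.tasks.append(lambda: setattr(ev, "complete", True))
        return ev

    def drain(self):
        for t in self.tasks:
            t()
        self.tasks.clear()

class OutOfOrderQueue:
    """Tasks may run in any order unless they declare dependencies."""
    def __init__(self):
        self.tasks = []

    def submit(self, fn, depends_on=()):
        self.tasks.append((fn, tuple(depends_on)))

    def drain(self):
        for fn, deps in self.tasks:
            # The model just checks every declared dependency has fired.
            assert all(ev.complete for ev in deps), "dependency not met"
            fn()
        self.tasks.clear()

shared = []
exchange_q = InOrderQueue()                         # in-order queue just for the exchange
exchange_q.submit(lambda: shared.append("tensor"))  # stands in for the data exchange
x = exchange_q.record_event()                       # event X after the exchange finishes

work_q = OutOfOrderQueue()                          # consumer's usual out-of-order queue
work_q.submit(lambda: shared.append("consumed"), depends_on=[x])

exchange_q.drain()
work_q.drain()
```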

oleksandr-pavlyk commented 3 years ago

@tqchen In SYCL, the consumer would always need to provide an in-order queue to ensure synchronization (since queues are out-of-order by default). Not passing any queue would require producer_queue.wait() to be executed.

There is a proposed SYCL extension to enable asynchronously waiting on an event on the entire queue, which may allow to_dlpack_on_stream to handle out-of-order queues as well.

Even so, using an in-order queue is going to be more performant.

tqchen commented 3 years ago

Trying to summarize the current state of the discussion, as well as pointing out another missing point (per-GPU streams). Here is a complete proposal under S0.

S0a: Producer Sync based on Stream Specified by Consumer

def __dlpack__(self, stream=None):
    """Get a DLTensor capsule that is visible in stream.

    The producer will take charge of dependency synchronization to stream.
    If stream is None, it defaults to the legacy default stream.
    """
    pass

def __dlpack_device__(self) -> Tuple[int, int]:
    """Return a (device_type, device_id) pair in DLPack convention"""

# consumer code:
def from_dlpack(producer):
    device = producer.__dlpack_device__()
    consumer_stream = consumer.find_exchange_stream(device)
    dlpack_caps = producer.__dlpack_stream__(consumer_stream)
    return convert_to_consumer_array(dlpack_caps)

Note that most systems associate a separate stream with each particular device. So if we want to use a non-default stream, then knowing which stream to synchronize to is important. The main benefits of this style of API were discussed earlier.

Given the need to move on quickly, we can also start with a reduced version, and continue more discussions on stream exchange.

S0reduced: Producer Sync based on Stream Specified by Consumer

def __dlpack__(self):
    """Get a DLTensor capsule that is visible on the legacy default stream.

    The producer will take charge of dependency synchronization to the
    legacy default stream of the corresponding device.
    """
    pass

This way the __dlpack_device__ function is not needed, and the behavior is well-defined and consistent between the GPU and CPU cases.
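
A minimal sketch of how both sides might look under the reduced protocol. `ToyProducer` and the inline capsule-to-array conversion are illustrative stand-ins, not real library code:

```python
class ToyProducer:
    def __init__(self, payload):
        self.payload = payload
        self.synced = False

    def __dlpack__(self):
        # A real producer would synchronize to the legacy default
        # stream of the tensor's device before handing out the capsule.
        self.synced = True
        return self.payload          # stands in for a DLTensor capsule

def from_dlpack(producer):
    capsule = producer.__dlpack__()  # data already visible on default stream
    return list(capsule)             # stands in for capsule-to-array conversion

p = ToyProducer((1, 2, 3))
arr = from_dlpack(p)
```

Note there is no stream negotiation at all: the synchronization contract is fixed by convention, which is exactly what makes the reduced version simpler.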

oleksandr-pavlyk commented 3 years ago

I think S0a is the way to go.

rgommers commented 3 years ago

Good catch regarding the need for __dlpack_device__. Regarding the consumer code, two minor tweaks:

    # This call is a consumer-internal implementation detail, doesn't need to have a standardized name
    consumer_stream = _find_exchange_stream(device)
    dlpack_caps = producer.__dlpack__(consumer_stream)  # This was a typo, there's no `__dlpack_stream__`

I think S0a is the way to go.

I agree, adding this second method makes sense and it's not too likely there's something that would need changing later.

kkraus14 commented 3 years ago

Seems like we've aligned on the Python side: a __dlpack_device__ protocol and a __dlpack__ protocol that takes a stream as input, where the producer is responsible for making the data it produces safe to use on the given stream.
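
Putting the agreed pieces together, a toy end-to-end sketch of the full protocol. The device code follows DLPack's convention (kDLCUDA = 2); the producer class, the stream strings, and the consumer's stream-selection helper are all illustrative:

```python
kDLCUDA = 2  # DLPack device_type code for CUDA devices

class ToyProducer:
    def __init__(self, payload, device_id=0):
        self.payload = payload
        self.device_id = device_id
        self.synced_to = None

    def __dlpack_device__(self):
        return (kDLCUDA, self.device_id)

    def __dlpack__(self, stream=None):
        # The producer must make `payload` safe to use on `stream`;
        # the toy version just records which stream it synced to.
        self.synced_to = "legacy-default" if stream is None else stream
        return self.payload          # stands in for a DLTensor capsule

def from_dlpack(producer):
    device_type, device_id = producer.__dlpack_device__()
    # Consumer-internal choice of exchange stream for that device:
    consumer_stream = f"stream-for-device-{device_id}"
    return producer.__dlpack__(consumer_stream)

p = ToyProducer([1, 2, 3])
arr = from_dlpack(p)
```

The key division of labor: the consumer only chooses the stream; all synchronization work happens inside the producer's `__dlpack__`.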

Presumably we need to solve this same problem for C libraries now. Should I open a new issue or do we want to continue in this issue?

tqchen commented 3 years ago

Thanks @kkraus14, how about we open a new issue? Given that there is no standardized interface for C library exchange, perhaps we could add that as an interface recommendation?