mitsuba-renderer / drjit

Dr.Jit — A Just-In-Time-Compiler for Differentiable Rendering
BSD 3-Clause "New" or "Revised" License

DLPack: Support handling of streams #211

Closed rtabbara closed 10 months ago

rtabbara commented 11 months ago

Motivated by #198, this PR extends the DLPack interface to support the stream argument. The consumer (PyTorch, JAX) can specify it to inform the producer (Dr.Jit) which stream will be used and to request that the producer perform any necessary synchronization.

Within the context of CUDA, as specified by the Python array API standard, the stream integer has a few special values (-1, 0, 1, 2) but otherwise corresponds to the associated CUstream handle. In the latter case, instead of calling jit_sync_thread, a new operation jit_cuda_sync_stream has been added to drjit-core. It uses CUDA event synchronization, which avoids blocking the CPU thread as jit_sync_thread would.
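As a rough illustration of how a producer might interpret the stream argument on CUDA, here is a minimal sketch in plain Python. It follows the value conventions described in the Python array API standard for `__dlpack__` (None and 1 mean the legacy default stream, 2 means the per-thread default stream, -1 requests no synchronization, 0 is disallowed as ambiguous, and anything greater than 2 is treated as a raw stream handle). The names `SyncAction` and `classify_cuda_stream` are hypothetical and are not part of Dr.Jit or drjit-core; the actual implementation in this PR may be structured differently.

```python
from enum import Enum
from typing import Optional, Tuple


class SyncAction(Enum):
    """Hypothetical labels for what the producer should do."""
    NONE = "no synchronization"
    LEGACY_DEFAULT = "synchronize with the legacy default stream"
    PER_THREAD_DEFAULT = "synchronize with the per-thread default stream"
    EXPLICIT_STREAM = "synchronize with the given CUstream handle"


def classify_cuda_stream(stream: Optional[int]) -> Tuple[SyncAction, Optional[int]]:
    """Map a DLPack `stream` argument to a synchronization action.

    Returns the action and, for explicit streams only, the handle value.
    """
    if stream is None:
        # The standard says the producer must assume the legacy default stream.
        return SyncAction.LEGACY_DEFAULT, None
    if isinstance(stream, bool) or not isinstance(stream, int):
        raise TypeError("stream must be None or a Python integer")
    if stream == -1:
        # The consumer explicitly asks for no synchronization.
        return SyncAction.NONE, None
    if stream == 0:
        # Disallowed: 0 is ambiguous between the default-stream variants.
        raise ValueError("stream=0 is ambiguous and not allowed for CUDA")
    if stream == 1:
        return SyncAction.LEGACY_DEFAULT, None
    if stream == 2:
        return SyncAction.PER_THREAD_DEFAULT, None
    if stream > 2:
        # Any other positive value is interpreted as a CUstream handle.
        return SyncAction.EXPLICIT_STREAM, stream
    raise ValueError(f"invalid stream value: {stream}")
```

Only the last successful branch (an explicit handle) corresponds to the case where event-based synchronization such as jit_cuda_sync_stream would apply; how the other cases are handled in Dr.Jit is not shown here.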

Some preliminary testing was performed in PyTorch by explicitly creating non-default streams, e.g.:

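import torch

s = torch.cuda.Stream()  # explicitly create a non-default CUDA stream
with torch.cuda.stream(s):  # make `s` the current stream inside this block
    ...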

The main goal was to confirm that the new interface correctly handles the case where a CUstream handle is provided.

wjakob commented 10 months ago

A few more minor comments -- please go ahead and merge the PR after those are addressed.