halide / Halide

a language for fast, portable data-parallel computation
https://halide-lang.org

Calling Halide PyTorch kernels from multiple threads causes an assertion failure #6286

Open twesterhout opened 3 years ago

twesterhout commented 3 years ago

Consider the following setup. I have written a Halide kernel and compile it AOT for CPU & CUDA. Then, I use Halide to also generate PyTorch wrappers for the kernel. Finally, this kernel is invoked from Python via the generated extension.

The CPU kernel works fine at all times. The CUDA kernel, however, triggers an assertion failure:

/home/halidenightly/build_bot/worker/halide-release_12-x86-64-linux-cmake/halide-source/src/runtime/cuda.cpp:232 Assert failed: context != nullptr

but only if the kernel is called from a different thread. In other words, something like kernel(input) works, but std::async(std::launch::async, kernel, input) fails.

Is there some per-thread initialization that I need to do manually?

twesterhout commented 3 years ago

It seems that one needs to initialize a CUcontext on every CPU thread from which Halide kernels are called. The Halide PyTorch wrapper calls cuCtxGetCurrent, but according to the documentation (https://developer.download.nvidia.com/compute/DevZone/docs/html/C/doc/html/group__CUDA__CTX_g8f13165846b73750693640fb3e8380d0.html#g8f13165846b73750693640fb3e8380d0) it may well return nullptr together with CUDA_SUCCESS. Halide therefore doesn't detect an error at that point, and the assertion fails later. A temporary fix is to create a dummy PyTorch tensor on every CPU thread and device with which we'd like to use Halide, but I feel that this is something Halide should do automatically... Does what I'm saying make sense, or am I misunderstanding the issue at hand?
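The dummy-tensor workaround can be sketched as follows (the `my_kernel` call is a hypothetical generated Halide PyTorch wrapper, and `torch` is imported lazily so the sketch can be read without it installed):

```python
def ensure_cuda_context(device: int = 0) -> None:
    """Workaround: force PyTorch to create and make current a CUDA
    context on the *calling* thread before invoking a Halide kernel."""
    import torch  # lazy import: sketch stays importable without torch
    if torch.cuda.is_available():
        # Allocating any tensor on the device initializes the driver
        # context for this thread; the tensor itself is discarded.
        torch.empty(1, device=f"cuda:{device}")

def run_kernel_in_thread(kernel, input_tensor):
    """Run a Halide-generated PyTorch kernel on a fresh thread,
    making sure the thread has a CUDA context first."""
    import threading
    result = {}
    def worker():
        ensure_cuda_context()          # per-thread initialization
        result["out"] = kernel(input_tensor)  # e.g. my_kernel(input_tensor)
    t = threading.Thread(target=worker)
    t.start()
    t.join()
    return result["out"]
```

Without the `ensure_cuda_context()` call inside `worker`, cuCtxGetCurrent on that thread returns a null context with CUDA_SUCCESS and the runtime assertion fires.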