NVIDIA / cuda-python

CUDA Python Low-level Bindings
https://nvidia.github.io/cuda-python/

`cudart.cudaSetDevice` before `cudart.cudaGetDevice` produces invalid results #24

Closed: wence- closed this issue 1 year ago

wence- commented 1 year ago

Using an environment created with:

mamba create -n testing -c nvidia -c conda-forge python=3.9 'cuda-toolkit>=11.7' 'cuda-python>=11.7'

running this script:

from cuda import cudart

print(cudart.cudaSetDevice(2))
print(cudart.cudaGetDevice())

Output:

(<cudaError_t.cudaSuccess: 0>,)
(<cudaError_t.cudaSuccess: 0>, 0)

Expected result: the cudaGetDevice() call should return device 2, not device 0.

The problem appears to be that cudaSetDevice only calls ccudart.utils.lazyInitGlobal, whereas cudaGetDevice calls ccudart.utils.lazyInit, which in turn calls lazyInitDevice(0).

I think cudaGetDevice simply shouldn't call lazyInit: the case where no context is current is already handled by the branch that calls cudaSetDevice(0).
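To make the suspected ordering problem concrete, here is a toy Python model. It is a hypothetical simplification, not the ccudart code: the class and method names are invented, global initialization is omitted, and it assumes that lazyInitDevice(0) leaves device 0's context current, which is one way to account for the reported output.

```python
# Toy model of the lazy-initialization ordering problem.
# Hypothetical: none of these names exist in cuda-python, and the assumption
# that initializing device 0 makes its context current is ours.


class ToyRuntime:
    def __init__(self, num_devices):
        self.contexts = [f"ctx{i}" for i in range(num_devices)]  # one context per device
        self.current = None  # context bound to the calling thread, if any

    def lazy_init_device(self, ordinal):
        # Assumption: initializing a device also makes its context current.
        self.current = self.contexts[ordinal]

    def lazy_init(self):
        # Stand-in for lazyInit, which the report says calls lazyInitDevice(0).
        self.lazy_init_device(0)

    def set_device(self, ordinal):
        # Binds the requested device's context (global init omitted here).
        self.current = self.contexts[ordinal]

    def get_device(self, call_lazy_init):
        if call_lazy_init:
            self.lazy_init()  # the step the report suggests removing
        if self.current is None:
            self.set_device(0)  # existing "no context yet" branch
        return self.contexts.index(self.current)


rt = ToyRuntime(num_devices=3)
rt.set_device(2)
print(rt.get_device(call_lazy_init=True))   # 0: lazy_init rebound device 0
rt.set_device(2)
print(rt.get_device(call_lazy_init=False))  # 2
```

Under this assumption, dropping the lazy_init call is enough on its own, because the `current is None` branch still covers a thread that has no context yet.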

https://github.com/NVIDIA/cuda-python/blob/main/cuda/_lib/ccudart/ccudart.pyx#L1039-L1045

Perhaps a patch like this?

diff --git a/cuda/_lib/ccudart/ccudart.pyx b/cuda/_lib/ccudart/ccudart.pyx
index d42d594..d7f3602 100644
--- a/cuda/_lib/ccudart/ccudart.pyx
+++ b/cuda/_lib/ccudart/ccudart.pyx
@@ -1032,9 +1032,6 @@ cdef cudaError_t _cudaGetDevice(int* device) nogil except ?cudaErrorCallRequires
     cdef cudaError_t err
     cdef ccuda.CUresult err_driver
     cdef ccuda.CUcontext context
-    err = m_global.lazyInit()
-    if err != cudaSuccess:
-        return err

     err_driver = ccuda._cuCtxGetCurrent(&context)
     if err_driver == ccuda.cudaError_enum.CUDA_ERROR_INVALID_CONTEXT or (err_driver == ccuda.cudaError_enum.CUDA_SUCCESS and context == NULL):
@@ -1045,14 +1042,16 @@ cdef cudaError_t _cudaGetDevice(int* device) nogil except ?cudaErrorCallRequires
         err_driver = ccuda._cuCtxGetCurrent(&context)

     if err_driver != ccuda.cudaError_enum.CUDA_SUCCESS:
-        _setLastError(err)
-        return err
+        _setLastError(<cudaError_t>err_driver)
+        return <cudaError_t>err_driver

     found = False
     for deviceOrdinal in range(m_global._numDevices):
         if m_global._driverContext[deviceOrdinal] == context:
             found = True
             break
+    else:
+        return cudaErrorDeviceUninitialized
     device[0] = deviceOrdinal if found else 0
     return cudaSuccess

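Note that this patch also makes two other fixes:

  1. When err_driver != CUDA_SUCCESS, actually return the error code, rather than the previous err, which may still be cudaSuccess.
  2. If the current context still can't be matched to a device after all this, return cudaErrorDeviceUninitialized instead of silently reporting device 0 (I'm not sure this is the correct error code).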
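A hypothetical Python model of the patched lookup may help check these two fixes. The function, its parameters, and the Err stand-in enum are illustrative only and not part of cuda-python. Python's for/else has the same semantics as the `for ... else` the patch adds in Cython: the else branch runs only when the loop finishes without hitting `break`.

```python
# Hypothetical model of the patched tail of _cudaGetDevice; not cuda-python code.
from enum import Enum


class Err(Enum):
    # Stand-ins for cudaError_t values; DRIVER_FAILURE represents any
    # non-success result from the final cuCtxGetCurrent call.
    SUCCESS = "cudaSuccess"
    DEVICE_UNINITIALIZED = "cudaErrorDeviceUninitialized"
    DRIVER_FAILURE = "driver error"


def get_device(driver_err, context, driver_contexts, device):
    """driver_err is None when cuCtxGetCurrent succeeded; device is a
    one-element list standing in for the int* out-parameter."""
    # Fix 1: report the driver failure instead of a stale success code.
    if driver_err is not None:
        return driver_err
    # Fix 2: the else branch runs only if the loop never reaches break.
    for ordinal, ctx in enumerate(driver_contexts):
        if ctx == context:
            break
    else:
        return Err.DEVICE_UNINITIALIZED
    device[0] = ordinal
    return Err.SUCCESS


out = [None]
print(get_device(None, "ctx2", ["ctx0", "ctx1", "ctx2"], out), out[0])  # Err.SUCCESS 2
print(get_device(None, "foreign", ["ctx0", "ctx1"], [None]))  # Err.DEVICE_UNINITIALIZED
```

The model also shows that with zero known devices the else branch fires immediately, so the out-parameter is never written.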

vzhurba01 commented 1 year ago

Release v11.8.0 resolves this issue by bringing the Runtime context management in line with the CUDA Runtime. Thank you!

wence- commented 1 year ago

Thanks @vzhurba01 !