NVIDIA / cuda-python

CUDA Python Low-level Bindings
https://nvidia.github.io/cuda-python/

Missing conditional node support in Runtime #55

Closed: vzhurba01 closed this issue 7 months ago

vzhurba01 commented 7 months ago

A new graph node type was added for conditional nodes, but it is not handled in the Runtime shim layer.

galv commented 7 months ago

There is actually one more bug related to conditional nodes right now. This one is in the driver wrapper rather than the runtime layer.

cuGraphAddNode() writes to the conditional.phGraph_out field of its params argument; in other words, phGraph_out is an out parameter.

The trouble is that the code in CUDA_CONDITIONAL_NODE_PARAMS does not account for the fact that the phGraph_out array has conditional.size entries. Instead, it tracks the length in a separate variable, self._phGraph_out_length, which is never set. As a result, the returned list is empty even though the driver has updated the underlying phGraph_out pointer.

As a quick workaround, I set the length inside the size setter:

modified   cuda/cuda.pyx.in
@@ -8498,6 +8498,7 @@ cdef class CUDA_CONDITIONAL_NODE_PARAMS:
         return self._ptr[0].size
     @size.setter
     def size(self, unsigned int size):
+        self._phGraph_out_length = size
         self._ptr[0].size = size
     @property
     def phGraph_out(self):

A better fix, though, is probably to remove the setter for phGraph_out (users should never set it anyway), derive its length from the size field, and delete the _phGraph_out_length variable entirely.
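The bug and the proposed fix can be illustrated with a minimal pure-Python sketch. Everything below is a hypothetical mock: BuggyConditionalParams, FixedConditionalParams and fake_add_node are illustrative names, not part of cuda-python, and nothing here calls the CUDA driver. A ctypes array stands in for the phGraph_out pointer that cuGraphAddNode() fills in.

```python
import ctypes


class BuggyConditionalParams:
    """Hypothetical mock of the wrapper behaviour described above (not the real binding)."""

    def __init__(self):
        self.size = 0
        self._phGraph_out = None        # stands in for the C-level phGraph_out pointer
        self._phGraph_out_length = 0    # separate counter that nothing ever updates

    @property
    def phGraph_out(self):
        # Bug: the length comes from a counter the "driver" never touches.
        if self._phGraph_out is None:
            return []
        return [self._phGraph_out[i] for i in range(self._phGraph_out_length)]


class FixedConditionalParams:
    """Hypothetical mock of the proposed fix: length derived from size, no setter."""

    def __init__(self):
        self.size = 0
        self._phGraph_out = None

    @property
    def phGraph_out(self):
        # The driver writes exactly `size` graph handles, so `size` is the length.
        if self._phGraph_out is None:
            return []
        return [self._phGraph_out[i] for i in range(self.size)]

    # Deliberately no @phGraph_out.setter: it is an out parameter filled by the driver.


def fake_add_node(params):
    """Hypothetical stand-in for cuGraphAddNode(): stores `size` dummy graph handles."""
    handles = (ctypes.c_void_p * params.size)(*[0x1000 + i for i in range(params.size)])
    params._phGraph_out = handles


buggy = BuggyConditionalParams()
buggy.size = 2
fake_add_node(buggy)
print(buggy.phGraph_out)  # [] even though the "driver" wrote two handles

fixed = FixedConditionalParams()
fixed.size = 2
fake_add_node(fixed)
print(fixed.phGraph_out)  # [4096, 4097]
```

Because the fixed version has no setter, an assignment such as fixed.phGraph_out = [] raises AttributeError, which matches the point that users should never set this field.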

galv commented 7 months ago

One more thing.

I also wrote some code for the original issue you reported here; feel free to adapt it:

modified   cuda/_lib/ccudart/utils.pyx.in
@@ -3410,6 +3410,16 @@ cdef cudaError_t toDriverGraphNodeParams(const cudaGraphNodeParams *rtParams, cc
     elif rtParams[0].type == cudaGraphNodeType.cudaGraphNodeTypeMemFree:
         driverParams[0].type = ccuda.CUgraphNodeType_enum.CU_GRAPH_NODE_TYPE_MEM_FREE
         driverParams[0].free.dptr = <ccuda.CUdeviceptr>rtParams[0].free.dptr
+    elif rtParams[0].type == cudaGraphNodeType.cudaGraphNodeTypeConditional:
+        driverParams[0].type = ccuda.CUgraphNodeType_enum.CU_GRAPH_NODE_TYPE_CONDITIONAL
+        driverParams[0].conditional.handle = rtParams[0].conditional.handle
+        driverParams[0].conditional.type = <ccuda.CUgraphConditionalNodeType> rtParams[0].conditional.type
+        driverParams[0].conditional.size = rtParams[0].conditional.size
+        err = <cudaError_t>ccuda._cuCtxGetCurrent(&context)
+        if err != cudaSuccess:
+            _setLastError(err)
+            return err
+        driverParams[0].conditional.ctx = context
     else:
         return cudaErrorInvalidValue
     return cudaSuccess
@@ -3418,7 +3428,9 @@ cdef cudaError_t toDriverGraphNodeParams(const cudaGraphNodeParams *rtParams, cc
 cdef void toCudartGraphNodeOutParams(const ccuda.CUgraphNodeParams *driverParams, cudaGraphNodeParams *rtParams) nogil:
     if driverParams[0].type == ccuda.CUgraphNodeType_enum.CU_GRAPH_NODE_TYPE_MEM_ALLOC:
         rtParams[0].alloc.dptr = <void *>driverParams[0].alloc.dptr
-
+    elif driverParams[0].type == ccuda.CUgraphNodeType_enum.CU_GRAPH_NODE_TYPE_CONDITIONAL:
+        rtParams[0].conditional._phGraph_out = driverParams[0].conditional._ptr[0].phGraph_out
+        rtParams[0].conditional._phGraph_out_length = driverParams[0].conditional._phGraph_out_length

I have not thoroughly tested this by any means, but it is working for me.

vzhurba01 commented 7 months ago

Thank you @galv. I've submitted a fix for the driver as well. Let me know if something is still not working as expected.

Resolved in https://github.com/NVIDIA/cuda-python/commit/dfd31fa609b9c81bcff925824f38531ab3c96706