openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

Marking a CUDA custom call as command buffer-compatible has no effect #14889

Closed andportnoy closed 1 month ago

andportnoy commented 1 month ago

This seems to happen because this piece of logic only looks at registrations for the generic platform "gpu" (https://github.com/openxla/xla/blob/5e9ce9715ca3c43f8bf81c7ebdcf1c8642fa5a5d/xla/service/gpu/command_buffer_scheduling.cc#L162-L165), so custom calls registered for the "CUDA" platform are not taken into account.

@ezhulenev has suggested offline that the fix might be to canonicalize platform names more thoroughly.
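One possible reading of that suggestion, sketched in Python with hypothetical names (the actual fixes in XLA may work differently): normalize every platform name to a single canonical key at both registration and lookup time, so "gpu", "CUDA", and "cuda" all resolve to the same entry. Mapping the generic "gpu" alias to "cuda" is an assumption that only makes sense for a CUDA (not ROCm) build.

```python
# Hypothetical canonicalization: every alias of a platform maps to one key.
_ALIASES = {"gpu": "cuda"}  # assumption: a CUDA (not ROCm) build


def canonical_platform(name: str) -> str:
    """Lower-case the platform name, then resolve known aliases."""
    lowered = name.lower()
    return _ALIASES.get(lowered, lowered)


_traits: dict[tuple[str, str], int] = {}


def register(name: str, platform: str, traits: int = 0) -> None:
    """Store traits under the canonical platform name."""
    _traits[(name, canonical_platform(platform))] = traits


def lookup_traits(name: str, platform: str) -> int:
    """Look up traits using the same canonicalization, so aliases agree."""
    return _traits.get((name, canonical_platform(platform)), 0)


register("foo_fwd", "CUDA", traits=1)
print(lookup_traits("foo_fwd", "gpu"))  # 1
```

Because registration and lookup share one normalization function, it no longer matters which alias the caller used.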

A quick way to repro is to modify the JAX cuda_custom_call test as follows:

diff --git a/docs/cuda_custom_call/cuda_custom_call_test.py b/docs/cuda_custom_call/cuda_custom_call_test.py
index 563462feb..0e3a5453b 100644
--- a/docs/cuda_custom_call/cuda_custom_call_test.py
+++ b/docs/cuda_custom_call/cuda_custom_call_test.py
@@ -72,7 +72,8 @@ library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY)
 xla_client.register_custom_call_target(name=XLA_CUSTOM_CALL_TARGET_FWD,
                                        fn=ffi.pycapsule(library.FooFwd),
                                        platform=XLA_PLATFORM,
-                                       api_version=XLA_CUSTOM_CALL_API_VERSION)
+                                       api_version=XLA_CUSTOM_CALL_API_VERSION,
+                                       traits=1)

 # our forward primitive will also return the intermediate output b+1
@@ -111,7 +112,8 @@ mlir.register_lowering(foo_fwd_p, _foo_fwd_lowering, platform=JAX_PLATFORM)
 xla_client.register_custom_call_target(name=XLA_CUSTOM_CALL_TARGET_BWD,
                                        fn=ffi.pycapsule(library.FooBwd),
                                        platform=XLA_PLATFORM,
-                                       api_version=XLA_CUSTOM_CALL_API_VERSION)
+                                       api_version=XLA_CUSTOM_CALL_API_VERSION,
+                                       traits=1)
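The traits=1 argument is an integer bit set. As far as I know, bit 0 corresponds to the command-buffer-compatible trait in XLA's FFI C API (XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE = 1u << 0), but treat that as an assumption and check the header for your XLA version. The HandlerTraits class below is a hypothetical Python mirror for illustration, not something xla_client is claimed to export.

```python
import enum


class HandlerTraits(enum.IntFlag):
    """Hypothetical Python mirror of XLA's FFI handler trait bits (assumed values)."""
    NONE = 0
    COMMAND_BUFFER_COMPATIBLE = 1 << 0  # assumed to match `traits=1` in the repro


def has_trait(traits: int, trait: HandlerTraits) -> bool:
    """Return True if every bit of `trait` is set in the integer `traits`."""
    return (traits & trait) == trait


traits = int(HandlerTraits.COMMAND_BUFFER_COMPATIBLE)
print(traits)  # 1
print(has_trait(traits, HandlerTraits.COMMAND_BUFFER_COMPATIBLE))  # True
```

Using a flag enum rather than a bare 1 makes the intent of the argument readable and leaves room for additional trait bits.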

Then run the following commands (you'll need the Nsight Systems CLI installed). The resulting report shows whether each kernel was launched as part of a CUDA graph:

XLA_FLAGS=--xla_gpu_graph_min_graph_size=1 nsys profile --cuda-graph-trace=node -o custom-call-graph --force-overwrite=true python cuda_custom_call_test.py
nsys stats -r cuda_kern_exec_trace --force-export=true custom-call-graph.nsys-rep
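If you rerun the repro often, the two commands can be wrapped in a small Python driver. This is a convenience sketch: it reuses exactly the flags shown above, assumes nsys is on PATH and that cuda_custom_call_test.py is in the working directory, and the helper names are made up. XLA_FLAGS is set only in the child process's environment. The flag presumably lowers XLA's minimum graph size so that even a single custom call is eligible for capture.

```python
import os
import subprocess

REPORT = "custom-call-graph"  # output basename, as in the commands above


def build_commands(script: str = "cuda_custom_call_test.py"):
    """Return (env, profile_cmd, stats_cmd) for the nsys repro."""
    env = dict(os.environ, XLA_FLAGS="--xla_gpu_graph_min_graph_size=1")
    profile_cmd = [
        "nsys", "profile", "--cuda-graph-trace=node",
        "-o", REPORT, "--force-overwrite=true",
        "python", script,
    ]
    stats_cmd = [
        "nsys", "stats", "-r", "cuda_kern_exec_trace",
        "--force-export=true", f"{REPORT}.nsys-rep",
    ]
    return env, profile_cmd, stats_cmd


def main() -> None:
    env, profile_cmd, stats_cmd = build_commands()
    subprocess.run(profile_cmd, env=env, check=True)  # profile with graph-node tracing
    subprocess.run(stats_cmd, check=True)  # print the per-kernel execution trace
```

Call main() on a machine with a CUDA GPU and Nsight Systems installed. check=True stops the driver if profiling fails, so stale reports from an earlier run are never analyzed by mistake.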

phu0ngng commented 1 month ago

Hi @ezhulenev, I can confirm that CUDA graphs now show up in the nsys reports with the fixes introduced in #14921 and #15021, so we can close this issue. Many thanks.

hawkinsp commented 1 month ago

Closing, per @phu0ngng's report that this is fixed.