NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[PyTorch] Runtime lookup for CUDA Driver API calls in Userbuffers #970

Closed: denera closed this pull request 3 months ago

denera commented 3 months ago

Description

PR #901 made Userbuffers a permanent part of the TE/PyTorch extension, which necessitated linking to libcuda.so at compile time to support CUDA multicast functionality (e.g. calls to cuMemXYZ and cuMulticastXYZ). This seems to have broken minor version compatibility with older drivers, causing missing symbol errors when loading the extension library.

This PR avoids linking to libcuda.so at compile time by instead retrieving the addresses of the necessary CUDA Driver functions through the CUDA Runtime API via cudaGetDriverEntryPoint().
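For reference, the lookup pattern boils down to something like the sketch below (illustrative only, not the exact code in this PR; the helper name get_driver_symbol is made up here). cudaGetDriverEntryPoint() is available from CUDA 11.3; newer toolkits add an optional driverStatus out-parameter, which is left defaulted here:

```cpp
// Resolve a CUDA Driver API function at runtime through the CUDA Runtime,
// avoiding a compile-time link against libcuda.so.
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdexcept>
#include <string>

void *get_driver_symbol(const char *name) {
  void *fn = nullptr;
  cudaError_t err = cudaGetDriverEntryPoint(name, &fn, cudaEnableDefault);
  if (err != cudaSuccess || fn == nullptr) {
    throw std::runtime_error(std::string("failed to resolve driver symbol: ") + name);
  }
  return fn;
}

int main() {
  // The driver still has to be initialized before other driver calls.
  using cuInit_t = CUresult (*)(unsigned int);
  using cuDeviceGetCount_t = CUresult (*)(int *);
  auto init = reinterpret_cast<cuInit_t>(get_driver_symbol("cuInit"));
  auto device_get_count =
      reinterpret_cast<cuDeviceGetCount_t>(get_driver_symbol("cuDeviceGetCount"));
  int count = 0;
  if (init(0) == CUDA_SUCCESS && device_get_count(&count) == CUDA_SUCCESS) {
    // count now holds the number of visible CUDA devices.
  }
  return 0;
}
```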

denera commented 3 months ago

/te-ci pytorch

timmoon10 commented 3 months ago

It would be better to use the CUDA driver API in https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/common/util/cuda_driver.h, but this approach is fine as a quick solution if it helps avoid complicating https://github.com/NVIDIA/TransformerEngine/pull/760.

denera commented 3 months ago

> It would be better to use the CUDA driver API in https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/common/util/cuda_driver.h, but this approach is fine as a quick solution if it helps avoid complicating #760.

I didn’t know we had this in TE already. What I implemented is effectively identical, so I can just replace CUCALL with NVTE_CALL_CHECK_CUDA_DRIVER.
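For reference, that macro pattern amounts to roughly the following (a sketch, not TE's actual implementation; get_driver_symbol is the hypothetical helper from the sketch above):

```cpp
// Checked driver call in the spirit of NVTE_CALL_CHECK_CUDA_DRIVER (sketch
// only). decltype(&func) picks up the function's signature from cuda.h
// without creating a link-time dependency on libcuda.so.
#include <cuda.h>
#include <stdexcept>
#include <string>

void *get_driver_symbol(const char *name);  // hypothetical helper from above

#define CHECK_CUDA_DRIVER(func, ...)                                        \
  do {                                                                      \
    using func_t = decltype(&func);                                         \
    auto fn = reinterpret_cast<func_t>(get_driver_symbol(#func));           \
    CUresult status_ = fn(__VA_ARGS__);                                     \
    if (status_ != CUDA_SUCCESS) {                                          \
      throw std::runtime_error(std::string(#func) + " failed, CUresult " +  \
                               std::to_string(static_cast<int>(status_)));  \
    }                                                                       \
  } while (false)

// Usage: CHECK_CUDA_DRIVER(cuInit, 0);
```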

timmoon10 commented 3 months ago

Actually, the implementation in this PR is better than the existing CUDA driver infrastructure. Accessing CUDA driver functions via cudaGetDriverEntryPoint is safer than manually loading symbols from libcuda.so, and this PR caches symbols after loading. Ideally, we would modify Userbuffers to use NVTE_CALL_CHECK_CUDA_DRIVER and change its implementation to use cudaGetDriverEntryPoint.
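The caching mentioned above amounts to memoizing resolved pointers so each name is looked up at most once, e.g. (sketch only, reusing the hypothetical get_driver_symbol helper):

```cpp
// Cache resolved driver symbols so each lookup through
// cudaGetDriverEntryPoint() happens at most once, even across threads.
#include <mutex>
#include <string>
#include <unordered_map>

void *get_driver_symbol(const char *name);  // hypothetical helper from above

void *get_cached_driver_symbol(const std::string &name) {
  static std::unordered_map<std::string, void *> cache;
  static std::mutex cache_mutex;
  std::lock_guard<std::mutex> guard(cache_mutex);
  auto it = cache.find(name);
  if (it == cache.end()) {
    it = cache.emplace(name, get_driver_symbol(name.c_str())).first;
  }
  return it->second;
}
```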

denera commented 3 months ago

In that case, let’s merge this PR (after CI comes back clean), and then I’ll do a follow-up PR to update the implementation in our cuda_driver.h with this new approach.

ksivaman commented 3 months ago

/te-ci pytorch

cliffwoolley commented 3 months ago

I agree that cudaGetDriverEntryPoint is a better way to accomplish this than directly using dlsym(). I'm especially glad that the reference to .../compat/lib is removed, as that part was actively broken.

But there's yet a further improvement we can make, similar to what TensorFlow does, which lets the call site still look like cuWhatever() instead of call("cuWhatever").

E.g.:
https://github.com/tensorflow/tensorflow/blob/r2.11/tensorflow/compiler/xla/stream_executor/cuda/cuda_stub.cc#L37-L51
https://github.com/tensorflow/tensorflow/blob/r2.11/tensorflow/compiler/xla/stream_executor/cuda/cuda_11_2.inc#L3-L9

PyTorch further streamlines this with a few macros for stubbing functions that take only a few arguments (1, 2, 3, or 4): https://github.com/pytorch/pytorch/blob/0680e6cd1c5120d458c41544f276e5863f0a8396/aten/src/ATen/cuda/detail/LazyNVRTC.cpp#L159-L171, and only spells everything out for functions with many arguments: https://github.com/pytorch/pytorch/blob/0680e6cd1c5120d458c41544f276e5863f0a8396/aten/src/ATen/cuda/detail/LazyNVRTC.cpp#L174-L287.
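In that style, each driver function gets a thin stub carrying the real signature, so callers keep writing cuWhatever(...) directly. A rough sketch (hypothetical, again assuming a get_driver_symbol helper like the one above; this only works when the binary does not also link against libcuda.so):

```cpp
// Lazily-resolved stub in the TensorFlow/PyTorch style linked above: the
// call site stays cuMemGetAllocationGranularity(...), and the real symbol
// is resolved on first use, then cached in a function-local static.
#include <cuda.h>

void *get_driver_symbol(const char *name);  // hypothetical helper

extern "C" CUresult CUDAAPI cuMemGetAllocationGranularity(
    size_t *granularity, const CUmemAllocationProp *prop,
    CUmemAllocationGranularity_flags option) {
  using fn_t = CUresult (*)(size_t *, const CUmemAllocationProp *,
                            CUmemAllocationGranularity_flags);
  static fn_t real = reinterpret_cast<fn_t>(
      get_driver_symbol("cuMemGetAllocationGranularity"));
  return real(granularity, prop, option);
}
```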

denera commented 3 months ago

I updated the get_symbol() implementation in te/common/util/cuda_driver.h to use cudaGetDriverEntryPoint(), and replaced CUCALL in Userbuffers with NVTE_CALL_CHECK_CUDA_DRIVER.

@cliffwoolley I like how PyTorch handles these stubs. We should definitely do this in TE in a follow-up PR. Going through NVTE_CALL_CHECK_CUDA_DRIVER works for now just to unblock TE imports on older drivers, but inferring the function signature from the arguments necessitates a lot of explicit casting, which produces very verbose and ugly lines of code.
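To illustrate the verbosity (a hypothetical call site, not a line from this PR, assuming TE's NVTE_CALL_CHECK_CUDA_DRIVER macro from transformer_engine/common/util/cuda_driver.h is in scope):

```cpp
// Because the macro infers the driver function's signature from its
// arguments, each argument must already have exactly the right type,
// which forces explicit casts at every call site:
void reserve_va_range(CUdeviceptr *ptr, size_t bytes) {
  NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressReserve, ptr, bytes,
                              static_cast<size_t>(0),       // alignment
                              static_cast<CUdeviceptr>(0),  // fixed address
                              static_cast<unsigned long long>(0));  // flags
}
// With the stub approach, the same call would simply be:
//   cuMemAddressReserve(ptr, bytes, 0, 0, 0);
```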

timmoon10 commented 3 months ago

/te-ci pytorch

ksivaman commented 3 months ago

Re-ran 16134360