pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Retry on CUBLAS_STATUS_ALLOC_FAILED #128601

Open · david-macleod opened this issue 3 months ago

david-macleod commented 3 months ago

🐛 Describe the bug

This is not technically a bug, so perhaps a feature request would be more appropriate.

We are encountering a scenario where, when running multiple libtorch threads, we periodically run out of memory. This triggers the CUDA caching allocator to flush its free segments and retry the cudaMalloc, which is fine, and processing continues.
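For illustration, here is a minimal sketch of that fallback pattern expressed as a standalone helper rather than the allocator's actual internal code. The helper name `cudaMallocWithRetry` is hypothetical; it assumes `c10::cuda::CUDACachingAllocator::emptyCache()` is available to release the cached free segments.

```cpp
#include <cuda_runtime.h>
#include <c10/cuda/CUDACachingAllocator.h>

// Illustrative only: on an out-of-memory failure from cudaMalloc, release
// the free segments PyTorch is caching and retry the allocation once.
cudaError_t cudaMallocWithRetry(void** ptr, size_t size) {
  cudaError_t err = cudaMalloc(ptr, size);
  if (err == cudaErrorMemoryAllocation) {
    (void)cudaGetLastError();  // clear the sticky error state before retrying
    c10::cuda::CUDACachingAllocator::emptyCache();  // flush cached free segments
    err = cudaMalloc(ptr, size);
  }
  return err;
}
```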

However, sometimes the action that triggers the OOM is a call to cublasCreate (on the creation of a new thread), which presumably makes an implicit cudaMalloc call that is not managed by PyTorch's caching allocator.

RuntimeError: CUDA error: CUBLAS_STATUS_ALLOC_FAILED when calling cublasCreate(handle)

It would be useful for us if this call to cublasCreate triggered the same fallback behaviour: flush the PyTorch-managed cache and retry the cublasCreate call.

Does this seem reasonable?
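A rough sketch of the requested behaviour, not an actual PyTorch implementation: the wrapper name `createHandleWithRetry` is hypothetical, and it assumes flushing the cache via `c10::cuda::CUDACachingAllocator::emptyCache()` frees enough memory for cuBLAS's internal allocation to succeed on the second attempt.

```cpp
#include <cublas_v2.h>
#include <c10/cuda/CUDACachingAllocator.h>

// If creating a cuBLAS handle fails because its internal allocation cannot
// get memory, flush the blocks cached by PyTorch's allocator and retry once.
cublasStatus_t createHandleWithRetry(cublasHandle_t* handle) {
  cublasStatus_t status = cublasCreate(handle);
  if (status == CUBLAS_STATUS_ALLOC_FAILED) {
    // Return cached-but-unused segments to the driver so cuBLAS's own
    // cudaMalloc has a chance to succeed.
    c10::cuda::CUDACachingAllocator::emptyCache();
    status = cublasCreate(handle);  // retry the handle creation once
  }
  return status;
}
```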

Versions

Collecting environment information...
PyTorch version: 2.3.0a0+ebedce2
Is debug build: False
CUDA used to build PyTorch: 12.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 15.0.7
CMake version: version 3.28.1
Libc version: glibc-2.35

Python version: 3.10.13 (main, Sep 5 2023, 06:03:44) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: 12.3.107
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration:
GPU 0: NVIDIA A100-PCIE-40GB
GPU 1: NVIDIA A100-PCIE-40GB
GPU 2: NVIDIA A100-PCIE-40GB
GPU 3: NVIDIA A100-PCIE-40GB
GPU 4: NVIDIA A100-PCIE-40GB
GPU 5: NVIDIA A100-PCIE-40GB
GPU 6: NVIDIA A100-PCIE-40GB
GPU 7: NVIDIA A100-PCIE-40GB

Nvidia driver version: 535.129.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.0.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

cc @ptrblck @msaroufim

malfet commented 3 months ago

This sounds reasonable. If you have a PR that accomplishes this, please do not hesitate to submit it. @ptrblck, do you know if CUBLAS_STATUS_ALLOC_FAILED is always a recoverable error?