pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

"RuntimeError: CUDA error: operation not supported" fixed by downgrading toolkit version #135126

Open aorenste opened 1 week ago

aorenste commented 1 week ago

🐛 Describe the bug

After #134373 I started getting the error "RuntimeError: CUDA error: operation not supported" when trying to run PyTorch. A fresh build from source succeeds before #134373 and fails at or after it.

Error:

$ python test/inductor/test_triton_kernels.py -k test_triton_kernel_native
ETEST SUITE EARLY TERMINATION due to torch.cuda.synchronize() failure
CUDA error: operation not supported
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
======================================================================
ERROR: test_triton_kernel_native_grad_False_dynamic_False_backend_aot_eager (__main__.KernelTests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/aorenste/local/pytorch/torch/testing/_internal/common_utils.py", line 2979, in wrapper
    method(*args, **kwargs)
  File "/home/aorenste/local/pytorch/torch/testing/_internal/common_utils.py", line 532, in instantiated_test
    test(self, **param_kwargs)
  File "/data/users/aorenste/miniconda3/envs/py39/lib/python3.9/unittest/mock.py", line 1336, in patched
    return func(*newargs, **newkeywargs)
  File "/data/users/aorenste/pytorch/test/inductor/test_triton_kernels.py", line 888, in test_triton_kernel_native
    t1 = torch.rand(5, device=GPU_TYPE, requires_grad=grad)
RuntimeError: CUDA error: operation not supported
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
To execute this test, run the following from the base repo dir:
    python test/inductor/test_triton_kernels.py KernelTests.test_triton_kernel_native_grad_False_dynamic_False_backend_aot_eager
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
----------------------------------------------------------------------
Ran 1 test in 0.698s
FAILED (errors=1)

I'm not sure what version of the toolkit I started on - I think it was 12.2. I definitely tried 12.4 and 12.6 and they also failed. Switching to 12.0 succeeded.

Versions

Collecting environment information...
PyTorch version: 2.5.0a0+gitae3aa8f
Is debug build: False
CUDA used to build PyTorch: 12.0
ROCM used to build PyTorch: N/A

OS: CentOS Stream 9 (x86_64)
GCC version: (GCC) 11.4.1 20231218 (Red Hat 11.4.1-3)
Clang version: Could not collect
CMake version: version 3.30.2
Libc version: glibc-2.34

Python version: 3.8.19 | packaged by conda-forge | (default, Mar 20 2024, 12:47:35) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.19.0-0_fbk12_hardened_11583_g0bef9520ca2b-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 12.0.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA PG509-210
Nvidia driver version: 525.105.17
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: False

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 22
On-line CPU(s) list: 0-21
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8339HC CPU @ 1.80GHz
CPU family: 6
Model: 85
Thread(s) per core: 1
Core(s) per socket: 22
Socket(s): 1
Stepping: 11
BogoMIPS: 3591.73
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon $
Virtualization: VT-x
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 704 KiB (22 instances)
L1i cache: 704 KiB (22 instances)
L2 cache: 88 MiB (22 instances)
L3 cache: 16 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-21
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Vulnerable
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2: Vulnerable, IBPB: disabled, STIBP: disabled
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] flake8==6.1.0
[pip3] flake8-bugbear==23.3.23
[pip3] flake8-comprehensions==3.15.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==23.3.1
[pip3] flake8-simplify==0.19.3
[pip3] mypy==1.10.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.3
[pip3] optree==0.12.1
[pip3] pytorch-triton==3.0.0+dedb7bdf33
[pip3] torch==2.5.0a0+gitae3aa8f
[pip3] torchvision==0.18.0a0
[conda] mkl 2023.2.0 h84fe81f_50496 conda-forge
[conda] mkl-include 2024.2.0 ha957f24_665 conda-forge
[conda] numpy 1.24.3 pypi_0 pypi
[conda] optree 0.12.1 pypi_0 pypi
[conda] pytorch-triton 3.0.0+dedb7bdf33 pypi_0 pypi
[conda] torch 2.5.0a0+gitae3aa8f dev_0
[conda] torchvision 0.18.0a0 dev_0

cc @malfet @seemethere @ptrblck @msaroufim @ezyang @chauhang @penguinwu

malfet commented 1 week ago

@ptrblck FYI: a few users have run into this problem and say that downgrading to the CUDA 12.0 runtime fixed it, even though the 525 driver should be compatible with the 12.4 runtime, shouldn't it?

drisspg commented 1 week ago

I am also getting this when building with the same driver + 12.3 runtime.
Driver: 525.105.12
Nvcc Version: Build cuda_12.3.r12.3/compiler.33567101_0

zou3519 commented 1 week ago

Can we yank the PR? Multiple PyTorch devs have run into this.

ptrblck commented 1 week ago

FYI: few of the users run into this problem somehow and saying that downgrading to CUDA-12.0 runtime fixed their problems, even though 525 driver should be compatible with 12.4 runtime, shouldn't it?

Yes, minor version compatibility is supported for all 12.x builds when the driver is >=525.60.13. However, note that the "large kernel parameter" change requires a driver >=R530, as described in the announcement blog post:

Note that use of CUDA Toolkit 12.1 and a R530 driver or higher are required to compile, launch, and debug kernels with large kernel parameters. CUDA will issue the CUDA_ERROR_NOT_SUPPORTED error if the launch is attempted on an older driver.
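To check a machine against the two thresholds mentioned in this thread, a small script along these lines may help. It is only a sketch: the function names are hypothetical, and the thresholds (525.60.13 for 12.x minor version compatibility, the R530 branch for large kernel parameters) are copied from the comment above rather than queried from any NVIDIA API. It assumes `nvidia-smi` is on the PATH and returns `None` otherwise.

```python
import re
import subprocess
from typing import Optional, Tuple

# Thresholds copied from the comment above (assumptions, not queried from CUDA).
MIN_DRIVER_CUDA12_MINOR_COMPAT = (525, 60, 13)
MIN_DRIVER_BRANCH_LARGE_KERNEL_PARAMS = 530  # "R530 driver or higher"


def parse_driver_version(text: str) -> Tuple[int, ...]:
    """Turn a driver string such as '525.105.17' into (525, 105, 17)."""
    match = re.fullmatch(r"\s*(\d+(?:\.\d+)*)\s*", text)
    if match is None:
        raise ValueError(f"unrecognized driver version: {text!r}")
    return tuple(int(part) for part in match.group(1).split("."))


def supports_cuda12_minor_compat(driver: str) -> bool:
    """True if the driver meets the 12.x minor version compatibility floor."""
    return parse_driver_version(driver) >= MIN_DRIVER_CUDA12_MINOR_COMPAT


def supports_large_kernel_params(driver: str) -> bool:
    """True if the driver branch is R530 or newer."""
    return parse_driver_version(driver)[0] >= MIN_DRIVER_BRANCH_LARGE_KERNEL_PARAMS


def query_driver_version() -> Optional[str]:
    """Ask nvidia-smi for the driver version; None if it cannot be run."""
    try:
        out = subprocess.run(
            ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
            capture_output=True,
            text=True,
            check=True,
        ).stdout
    except (OSError, subprocess.CalledProcessError):
        return None
    lines = out.strip().splitlines()
    return lines[0].strip() if lines else None


if __name__ == "__main__":
    driver = query_driver_version()
    if driver is None:
        print("nvidia-smi not available; pass a version string manually")
    else:
        print("driver:", driver)
        print("12.x minor version compat:", supports_cuda12_minor_compat(driver))
        print("large kernel parameters:", supports_large_kernel_params(driver))
```

For the 525.105.17 driver reported in this issue, the first check passes and the second fails, which matches a build that works with the 12.0 toolkit but raises "operation not supported" with newer ones.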