pytorch / audio

Data manipulation and transformation for audio signal processing, powered by PyTorch
https://pytorch.org/audio
BSD 2-Clause "Simplified" License

`torchaudio.functional.rnnt_loss` crashes for `logits` with >2**31 elements #3736

Status: Open · opened by gpuplz 5 months ago

🐛 Describe the bug

Repro:

```python
import subprocess
import torch
import torchaudio

b, t, u, v = 33, 256, 64, 4096
print(
    torchaudio.functional.rnnt_loss(
        torch.zeros((b, t, u, v), dtype=torch.float16, device='cuda'),
        torch.ones((b, u-1), dtype=torch.int32, device='cuda'),
        torch.full((b,), t, dtype=torch.int32, device='cuda'),
        torch.full((b,), u-1, dtype=torch.int32, device='cuda'),
    ),
)
subprocess.check_call(['nvidia-smi'])
```

Observed result:

```
% python3 repro.py
Traceback (most recent call last):
  File "repro.py", line 7, in <module>
    torchaudio.functional.rnnt_loss(
  File "/home/me/.local/lib/python3.8/site-packages/torchaudio/functional/functional.py", line 1814, in rnnt_loss
    return costs.mean()
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```

If you change `b` from 33 to 32 (which brings `logits` down to exactly 2**31 elements), it runs successfully, and there is plenty of GPU memory left, which makes me think `b = 33` should fit too:

```
% python3 repro.py
tensor(2498., device='cuda:0', dtype=torch.float16)
Thu Jan 25 16:00:33 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A40                     On  | 00000000:11:00.0 Off |                    0 |
|  0%   37C    P0             155W / 300W |   8520MiB / 46068MiB |     42%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A     93893      C   python3                                    8508MiB |
+---------------------------------------------------------------------------------------+
```
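A quick back-of-the-envelope check of the `b = 32` vs `b = 33` comparison, in plain Python using only the shapes from the repro (the helper name `logits_numel` is just for illustration):

```python
# Shapes from the repro: logits is (b, t, u, v) in float16.
T, U, V = 256, 64, 4096
INT32_MAX = 2**31 - 1  # largest value a signed 32-bit int can hold
BYTES_PER_FP16 = 2


def logits_numel(b: int) -> int:
    """Number of elements in a (b, T, U, V) logits tensor."""
    return b * T * U * V


for b in (32, 33):
    n = logits_numel(b)
    # The largest flat index into the tensor is n - 1; it must fit in an int32.
    fits = n - 1 <= INT32_MAX
    gib = n * BYTES_PER_FP16 / 2**30
    print(f"b={b}: {n:,} elements, {gib:.3f} GiB, max index fits in int32: {fits}")
```

This prints `2,147,483,648 elements, 4.000 GiB ... True` for `b=32` and `2,214,592,512 elements, 4.125 GiB ... False` for `b=33`. So `b = 32` is the largest batch whose flat indices all fit in a signed 32-bit int, while the `b = 33` logits tensor itself is only about 4.1 GiB, far below the 46068 MiB available on the A40.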

I also ran this under NVIDIA compute-sanitizer, using torch==2.1.2 and torchaudio==2.1.2 that I built myself with `CMAKE_CUDA_FLAGS=-lineinfo`. It reported an invalid write at this line: https://github.com/pytorch/audio/blob/v2.1.2/torchaudio/csrc/rnnt/gpu/kernels.h#L86. That makes sense, since the line writes to `gradients[b_t_u_d]`, where `b_t_u_d` is a 32-bit signed int, and that index overflows once `logits` has more than 2**31 elements.
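To illustrate the overflow, here is a small pure-Python emulation of a signed 32-bit flat index. It assumes `gradients` is a contiguous, row-major `(B, maxT, maxU, D)` buffer, so element `(b, t, u, d)` sits at `((b * maxT + t) * maxU + u) * D + d`; the helpers `to_int32` and `flat_index` are hypothetical, not torchaudio code. Signed overflow is formally undefined behaviour in C++, so the two's-complement wraparound modelled here is an assumption about what happens on the GPU in practice (wrapping each intermediate product gives the same result modulo 2**32 as wrapping the exact value once).

```python
# Shapes from the repro: t, u, v = 256, 64, 4096.
MAX_T, MAX_U, D = 256, 64, 4096


def to_int32(x: int) -> int:
    """Wrap an arbitrary Python int to a two's-complement signed 32-bit value."""
    x &= 0xFFFFFFFF
    return x - 2**32 if x >= 2**31 else x


def flat_index(b: int, t: int, u: int, d: int) -> int:
    """Exact row-major offset of (b, t, u, d) in a contiguous (B, MAX_T, MAX_U, D) buffer."""
    return ((b * MAX_T + t) * MAX_U + u) * D + d


# Last element of batch item 31, then first and last elements of batch item 32.
probes = [
    (31, MAX_T - 1, MAX_U - 1, D - 1),
    (32, 0, 0, 0),
    (32, MAX_T - 1, MAX_U - 1, D - 1),
]
for coords in probes:
    exact = flat_index(*coords)
    print(coords, "exact:", exact, "as int32:", to_int32(exact))
```

The last element of batch item 31 lands exactly on 2147483647 (the int32 maximum), whereas every element of batch item 32, the 33rd one, wraps to a negative offset between -2147483648 and -2080374785, so the write goes before the start of the buffer. Doing this arithmetic in a 64-bit type such as `int64_t` would keep the offsets exact.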

```
% compute-sanitizer python3 repro.py
(...omitted...)
========= COMPUTE-SANITIZER
========= Invalid __global__ write of size 2 bytes
=========     at 0x8a0 in /audio/torchaudio/csrc/rnnt/gpu/kernels.h:86:void torchaudio::rnnt::ComputeGradientsElement<c10::Half, float>(int, int, int, int, int, int, int, T2, const T1 *, const int *, const int *, const int *, const T2 *, const T2 *, const T2 *, T1 *, int, bool)
=========     by thread (192,0,0) in block (0,24,32)
=========     Address 0x7f023a030000 is out of bounds
=========     and is 4194107392 bytes before the nearest allocation at 0x7f0334000000 of size 4429185024 bytes
=========     Device Frame:/audio/torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh:339:void torchaudio::rnnt::ComputeGradients<c10::Half, float>(int, int, int, int, T2, const T1 *, const int *, const int *, const int *, const T2 *, const T2 *, const T2 *, T1 *, int, bool) [0x10]
(...omitted...)
```
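The sanitizer's numbers are consistent with an int32 wraparound. The sketch below rests on assumptions not stated in the report: that `ComputeGradients` maps `blockIdx.z` to the batch item, `blockIdx.x * blockDim.x + threadIdx.x` to `t` and `blockIdx.y` to `u`, that the faulting write is for `d = 0`, that the index is row-major over `(B, maxT, maxU, D)`, and that the "nearest allocation" is the buffer being written. Under those assumptions, thread (192,0,0) in block (0,24,32) is `b = 32, t = 192, u = 24`.

```python
# Shapes from the repro (b = 33) and the fp16 element size.
B, MAX_T, MAX_U, D = 33, 256, 64, 4096
FP16_BYTES = 2


def to_int32(x: int) -> int:
    """Wrap an arbitrary Python int to a two's-complement signed 32-bit value."""
    x &= 0xFFFFFFFF
    return x - 2**32 if x >= 2**31 else x


# Assumed mapping of "thread (192,0,0) in block (0,24,32)": t = 192, u = 24, b = 32, d = 0.
b, t, u, d = 32, 192, 24, 0
exact = ((b * MAX_T + t) * MAX_U + u) * D + d
wrapped = to_int32(exact)

byte_offset = wrapped * FP16_BYTES                # offset of the write from the buffer start
alloc_bytes = B * MAX_T * MAX_U * D * FP16_BYTES  # size of a (33, 256, 64, 4096) fp16 tensor
# Assumes the "nearest allocation" reported above is the buffer being written.
write_addr = 0x7F0334000000 + byte_offset

print(exact)            # 2197913600
print(wrapped)          # -2097053696
print(byte_offset)      # -4194107392, i.e. 4194107392 bytes before the buffer
print(alloc_bytes)      # 4429185024
print(hex(write_addr))  # 0x7f023a030000
```

Both the gap (4194107392 bytes before the allocation) and the allocation size (4429185024 bytes, the size of the `b = 33` fp16 logits) match the report, and the reconstructed address equals the reported one.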

Versions

```
% python3 collect_env.py
Collecting environment information...
PyTorch version: 2.1.2+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31

Python version: 3.8.10 (default, Nov 22 2023, 10:22:35)  [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A40
Nvidia driver version: 535.129.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      46 bits physical, 57 bits virtual
CPU(s):                             6
On-line CPU(s) list:                0-5
Thread(s) per core:                 2
Core(s) per socket:                 3
Socket(s):                          1
NUMA node(s):                       1
Vendor ID:                          GenuineIntel
CPU family:                         6
Model:                              106
Model name:                         Intel(R) Xeon(R) Gold 6326 CPU @ 2.90GHz
Stepping:                           6
CPU MHz:                            2899.998
BogoMIPS:                           5799.99
Virtualization:                     VT-x
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          96 KiB
L1i cache:                          96 KiB
L2 cache:                           12 MiB
L3 cache:                           16 MiB
NUMA node0 CPU(s):                  0-5
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed:             Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Mitigation; TSX disabled
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves wbnoinvd arat avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid md_clear arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.24.4
[pip3] torch==2.1.2
[pip3] torchaudio==2.1.2
[pip3] triton==2.1.0
[conda] Could not collect
```