Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

`2.5.0` has an issue accessing memory illegally during backward #812

Open seungduk-yanolja opened 9 months ago

seungduk-yanolja commented 9 months ago

The same issue as https://github.com/Dao-AILab/flash-attention/issues/338 is reproduced.

Traceback (most recent call last):
  File "/home/user/miniconda3/envs/axo/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/user/miniconda3/envs/axo/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/user/apps/axolotl/src/axolotl/cli/train.py", line 59, in <module>
    fire.Fire(do_cli)
  File "/home/user/miniconda3/envs/axo/lib/python3.10/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/home/user/miniconda3/envs/axo/lib/python3.10/site-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/home/user/miniconda3/envs/axo/lib/python3.10/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/home/user/apps/axolotl/src/axolotl/cli/train.py", line 35, in do_cli
    return do_train(parsed_cfg, parsed_cli_args)
  File "/home/user/apps/axolotl/src/axolotl/cli/train.py", line 55, in do_train
    return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
  File "/home/user/apps/axolotl/src/axolotl/train.py", line 167, in train
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  File "/home/user/miniconda3/envs/axo/lib/python3.10/site-packages/transformers/trainer.py", line 1561, in train
    return inner_training_loop(
  File "/home/user/miniconda3/envs/axo/lib/python3.10/site-packages/transformers/trainer.py", line 1893, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/user/miniconda3/envs/axo/lib/python3.10/site-packages/transformers/trainer.py", line 2822, in training_step
    self.accelerator.backward(loss)
  File "/home/user/miniconda3/envs/axo/lib/python3.10/site-packages/accelerate/accelerator.py", line 1964, in backward
    loss.backward(**kwargs)
  File "/home/user/miniconda3/envs/axo/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/home/user/miniconda3/envs/axo/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/user/miniconda3/envs/axo/lib/python3.10/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/home/user/miniconda3/envs/axo/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 288, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/home/user/miniconda3/envs/axo/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/user/miniconda3/envs/axo/lib/python3.10/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/home/user/miniconda3/envs/axo/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 318, in backward
    _flash_attn_varlen_backward(
  File "/home/user/miniconda3/envs/axo/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 179, in _flash_attn_varlen_backward
    dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
RuntimeError: CUDA error: an illegal memory access was encountered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

I tested various combinations such as transformers==4.37.2 and transformers==4.38.0.dev0 with flash-attn versions 2.5.2, 2.5.0, 2.4.3.post1, and 2.4.2. Versions 2.4.3.post1 and 2.4.2 do not have the issue, while 2.5.2 and 2.5.0 do. For the same dataset, it consistently occurred at the 46th global step. No other datasets have been tested yet.
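Since the Python traceback for an asynchronous CUDA failure often points at a later, unrelated call, a generic debugging step that may help narrow down the failing kernel (not specific to flash-attn) is to enable launch blocking before CUDA is initialized, for example at the very top of the training entry point:

```python
# Generic CUDA debugging switch (must be set before torch initializes CUDA):
# makes kernel launches synchronous so the error is raised at the real call site.
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Note: TORCH_USE_CUDA_DSA, mentioned in the error message, only takes effect
# if PyTorch itself is rebuilt with device-side assertions enabled.
```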

Environment
PyTorch version: 2.1.2
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.28.1
Libc version: glibc-2.35

Python version: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-92-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.2.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3

Nvidia driver version: 535.129.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.7
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      46 bits physical, 57 bits virtual
Byte Order:                         Little Endian
CPU(s):                             128
On-line CPU(s) list:                0-127
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) Platinum 8462Y+
CPU family:                         6
Model:                              143
Thread(s) per core:                 2
Core(s) per socket:                 32
Socket(s):                          2
Stepping:                           8
Frequency boost:                    enabled
CPU max MHz:                        2801.0000
CPU min MHz:                        800.0000
BogoMIPS:                           5600.00
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Virtualization:                     VT-x
L1d cache:                          3 MiB (64 instances)
L1i cache:                          2 MiB (64 instances)
L2 cache:                           128 MiB (64 instances)
L3 cache:                           120 MiB (2 instances)
NUMA node(s):                       2
NUMA node0 CPU(s):                  0-31,64-95
NUMA node1 CPU(s):                  32-63,96-127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.3
[pip3] torch==2.1.2
[pip3] torchaudio==2.1.2
[pip3] torchvision==0.16.2
[pip3] triton==2.1.0
[conda] blas                      1.0                         mkl
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
[conda] mkl                       2023.1.0         h213fc3f_46344
[conda] mkl-service               2.4.0           py310h5eee18b_1
[conda] mkl_fft                   1.3.8           py310h5eee18b_0
[conda] mkl_random                1.2.4           py310hdb19cb5_0
[conda] numpy                     1.26.3          py310h5f9d8c6_0
[conda] numpy-base                1.26.3          py310hb5e798b_0
[conda] pytorch                   2.1.2           py3.10_cuda12.1_cudnn8.9.2_0    pytorch
[conda] pytorch-cuda              12.1                 ha16c6d3_5    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torchaudio                2.1.2               py310_cu121    pytorch
[conda] torchtriton               2.1.0                     py310    pytorch
[conda] torchvision               0.16.2              py310_cu121    pytorch
tridao commented 9 months ago

Can you save the tensors being passed to flash_attn_cuda.varlen_bwd and send them to me? Otherwise it would be very hard to debug.

And can you print out the value of cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, and the shape of q, k, v?
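For anyone trying to capture that information without relying on the exact argument order of `_flash_attn_varlen_backward`, here is a minimal, illustrative logging shim (not part of flash-attn; it assumes the backward path resolves the function through the module-level name, which matches the traceback above):

```python
# Illustrative debugging shim: log every tensor shape and scalar value that
# reaches the varlen backward call, so the inputs of the failing step are known.
import torch
import flash_attn.flash_attn_interface as fai

_orig_varlen_bwd = fai._flash_attn_varlen_backward

def _logged_varlen_bwd(*args, **kwargs):
    for i, a in enumerate(args):
        if torch.is_tensor(a):
            # small 1-D int tensors are likely cu_seqlens; print their values too
            extra = f" values={a.tolist()}" if a.dim() == 1 and a.numel() <= 64 else ""
            print(f"varlen_bwd arg[{i}]: shape={tuple(a.shape)} dtype={a.dtype}{extra}")
        else:
            print(f"varlen_bwd arg[{i}]: {a!r}")
    return _orig_varlen_bwd(*args, **kwargs)

fai._flash_attn_varlen_backward = _logged_varlen_bwd
```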

tleyden commented 9 months ago

I'm hitting a similar crash:


Traceback (most recent call last):
  File "//scripts/run_sft.py", line 216, in <module>
    main()
  File "//scripts/run_sft.py", line 161, in main
    train_result = trainer.train()
  File "/opt/conda/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 280, in train
    output = super().train(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1555, in train
    return inner_training_loop(
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1860, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2735, in training_step
    self.accelerator.backward(loss)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/accelerator.py", line 1958, in backward
    self.deepspeed_engine_wrapped.backward(loss, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/utils/deepspeed.py", line 167, in backward
    self.engine.backward(loss, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1936, in backward
    self.optimizer.backward(loss, retain_graph=retain_graph)
  File "/opt/conda/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 2093, in backward
    self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
  File "/opt/conda/lib/python3.10/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
    scaled_loss.backward(retain_graph=retain_graph)
  File "/opt/conda/lib/python3.10/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 274, in apply
    return user_fn(self, *args)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 157, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 274, in apply
    return user_fn(self, *args)
  File "/opt/conda/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 530, in backward
    dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
RuntimeError: CUDA error: an illegal memory access was encountered
Additional native stack trace:

```
Exception raised from c10_cuda_check_implementation at /opt/conda/conda-bld/pytorch_1679586020379/work/c10/cuda/CUDAException.cpp:44 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7fc6100934d7 in /opt/conda/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7fc61005d36b in /opt/conda/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7fc648054fa8 in /opt/conda/lib/python3.10/site-packages/torch/lib/libc10_cuda.so)
frame #3: + 0x137bb (0x7fc6480257bb in /opt/conda/lib/python3.10/site-packages/torch/lib/libc10_cuda.so)
frame #4: + 0x22d80 (0x7fc648034d80 in /opt/conda/lib/python3.10/site-packages/torch/lib/libc10_cuda.so)
frame #5: + 0x4cd116 (0x7fc5aa489116 in /opt/conda/lib/python3.10/site-packages/torch/lib/libtorch_python.so)
frame #6: + 0x3ee77 (0x7fc610078e77 in /opt/conda/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #7: c10::TensorImpl::~TensorImpl() + 0x1be (0x7fc61007169e in /opt/conda/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #8: c10::TensorImpl::~TensorImpl() + 0x9 (0x7fc6100717b9 in /opt/conda/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #9: + 0x7526c8 (0x7fc5aa70e6c8 in /opt/conda/lib/python3.10/site-packages/torch/lib/libtorch_python.so)
frame #10: THPVariable_subclass_dealloc(_object*) + 0x305 (0x7fc5aa70ea55 in /opt/conda/lib/python3.10/site-packages/torch/lib/libtorch_python.so)
frame #11: + 0x12f83f (0x5615851d883f in /opt/conda/bin/python)
frame #12: + 0x12f8e6 (0x5615851d88e6 in /opt/conda/bin/python)
frame #13: + 0x15536d (0x5615851fe36d in /opt/conda/bin/python)
frame #14: + 0x13c848 (0x5615851e5848 in /opt/conda/bin/python)
frame #15: + 0x14e1bf (0x5615851f71bf in /opt/conda/bin/python)
frame #16: + 0x14e2a1 (0x5615851f72a1 in /opt/conda/bin/python)
frame #17: + 0x14e2a1 (0x5615851f72a1 in /opt/conda/bin/python)
frame #18: + 0x14e2a1 (0x5615851f72a1 in /opt/conda/bin/python)
frame #19: + 0x14e2a1 (0x5615851f72a1 in /opt/conda/bin/python)
frame #20: + 0x123659 (0x5615851cc659 in /opt/conda/bin/python)
frame #21: PyDict_SetItemString + 0x4a (0x5615851cfe1a in /opt/conda/bin/python)
frame #22: + 0x21a1fa (0x5615852c31fa in /opt/conda/bin/python)
frame #23: Py_FinalizeEx + 0x16f (0x5615852c251f in /opt/conda/bin/python)
frame #24: Py_RunMain + 0x108 (0x5615852b4608 in /opt/conda/bin/python)
frame #25: Py_BytesMain + 0x39 (0x561585282709 in /opt/conda/bin/python)
frame #26: __libc_start_main + 0xf3 (0x7fc679e71083 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #27: + 0x1d9611 (0x561585282611 in /opt/conda/bin/python)
```
Environment

* 4 x A10 GPUs
* Driver versions: `NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2`
* Flash attn version: `flash-attn 2.5.2`
* DeepSpeed version: `deepspeed 0.12.2`
* PyTorch version: `torch 2.0.0`
* Linux version: `Linux 2ddba115bd4f 5.10.209-198.812.amzn2.x86_64 #1 SMP Tue Jan 30 20:59:52 UTC 2024 x86_64 x86_64 x86_64 GNU/Linux`

@tridao let me know if I can collect any more info to help - if you have a code snippet or sample that would be useful. In my case flash attn is being invoked from deepspeed.

tridao commented 9 months ago

Can you save the tensors being passed to flash_attn_cuda.varlen_bwd and send them to me? Otherwise it would be very hard to debug.

And can you print out the value of cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, and the shape of q, k, v?

seungduk-yanolja commented 6 months ago

I tried to save the tensors when it happened, but it was challenging and I failed. One thing I can tell is that it happens when I modify the gradients using a hook, as explained here: https://huggingface.co/yanolja/KoSOLAR-10.7B-v0.2#technical-deep-dive
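For context, the hook described in that write-up is roughly of this form. The sketch below is a hypothetical reconstruction with illustrative names, not the actual KoSOLAR training code: it zeroes the gradient rows of embedding entries that are meant to stay frozen.

```python
# Hypothetical reconstruction of a row-masking gradient hook (illustrative only).
import torch

def freeze_rows_hook(frozen_row_ids):
    frozen = torch.as_tensor(list(frozen_row_ids))

    def hook(grad):
        grad = grad.clone()                       # don't modify autograd's gradient in place
        grad[frozen.to(grad.device)] = 0.0        # keep these embedding rows unchanged
        return grad

    return hook

# Example usage (token id range is a placeholder):
# model.get_input_embeddings().weight.register_hook(freeze_rows_hook(range(32000)))
```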

tridao commented 6 months ago

It would be hard for me to debug if I can't reproduce it. You can use a try/except to hopefully save the tensors.
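In case it helps others, here is a sketch of what that try/except could look like, again wrapping the varlen backward entry point (the wrapper and output file name are illustrative, not flash-attn API). Note that after an illegal memory access the CUDA context may already be unusable, so copying the tensors to CPU can itself fail; combining this with `CUDA_LAUNCH_BLOCKING=1` makes it more likely to work.

```python
# Illustrative try/except dump: if the varlen backward kernel raises, try to
# save its inputs so they can be shared for debugging.
import torch
import flash_attn.flash_attn_interface as fai

_orig_varlen_bwd = fai._flash_attn_varlen_backward

def _dumping_varlen_bwd(*args, **kwargs):
    try:
        return _orig_varlen_bwd(*args, **kwargs)
    except RuntimeError:
        cpu_args = [a.detach().cpu() if torch.is_tensor(a) else a for a in args]
        torch.save(cpu_args, "varlen_bwd_failure.pt")  # placeholder output path
        raise

fai._flash_attn_varlen_backward = _dumping_varlen_bwd
```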

leasunhy commented 6 months ago

Same here. Please kindly use the tensors here to reproduce the issue.

sfc-gh-bzhai commented 1 month ago

Same here, and the weird part is that if I load the dataset with load_dataset, it runs without issue, but if I use a local dataset loaded with load_from_disk, it triggers this issue.
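One way to check whether the two loading paths really feed the model the same data is to compare the splits directly, e.g. their features and maximum sequence length. This is a hypothetical sanity check; the dataset name, path, and column names below are placeholders.

```python
# Hypothetical sanity check: does load_from_disk give the same examples as
# load_dataset? Differences in length/padding could explain why only one path
# hits the illegal memory access.
from datasets import DatasetDict, load_dataset, load_from_disk

hub_ds = load_dataset("my_org/my_dataset", split="train")   # placeholder name
disk_ds = load_from_disk("/path/to/my_dataset")              # placeholder path
if isinstance(disk_ds, DatasetDict):
    disk_ds = disk_ds["train"]

print(len(hub_ds), len(disk_ds))
print(hub_ds.features)
print(disk_ds.features)

if "input_ids" in hub_ds.column_names and "input_ids" in disk_ds.column_names:
    print("max len (hub): ", max(len(x) for x in hub_ds["input_ids"]))
    print("max len (disk):", max(len(x) for x in disk_ds["input_ids"]))
```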