NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

[QST] IMA when attempting to update PyTorch to 3.2.1 #1138

Closed: drisspg closed this issue 9 months ago

drisspg commented 11 months ago

Summary

This PR, https://github.com/pytorch/pytorch/pull/108070, updates PyTorch's CUTLASS pin from 3.1 to 3.2.1. With this update, the FlashAttention tests are currently failing.

The update causes the FlashAttention kernel to hit an illegal memory access (IMA).

A minimal repro for this is:

import torch
from torch.nn.functional import scaled_dot_product_attention

# Changing seq_len from 129 to 128 makes the IMA go away, likely because the kernel then does only 1 iteration
query = torch.randn(1, 1, 129, 8, device="cuda", dtype=torch.bfloat16) # (batch, num_heads, seq_len, embed_dim)
key = torch.randn(1, 1, 129, 8, device="cuda", dtype=torch.bfloat16) # (batch, num_heads, seq_len, embed_dim)
value = torch.randn(1, 1, 129, 8, device="cuda", dtype=torch.bfloat16) # (batch, num_heads, seq_len, embed_dim)

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    scaled_dot_product_attention(query, key, value)
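The comment in the repro attributes the 128/129 difference to the number of loop iterations. A quick check, assuming the template arguments of Flash_fwd_kernel_traits<32, 128, 128, 4, ...> in the sanitizer log below are (head dim, kBlockM, kBlockN, number of warps) as in FlashAttention-2, is to count the 128-row tiles each sequence length needs. num_tiles is a hypothetical helper, not part of PyTorch or CUTLASS.

```python
# Assumption: kBlockM = kBlockN = 128, read off the kernel traits in the
# sanitizer log (Flash_fwd_kernel_traits<32, 128, 128, 4, ...>).
BLOCK_M = 128
BLOCK_N = 128


def num_tiles(seq_len: int, block: int) -> int:
    """Number of block-sized tiles needed to cover seq_len rows (ceiling division)."""
    return -(-seq_len // block)


for seq_len in (128, 129):
    m_tiles = num_tiles(seq_len, BLOCK_M)  # query tiles (one thread block each)
    n_tiles = num_tiles(seq_len, BLOCK_N)  # iterations of the key/value loop
    print(f"seq_len={seq_len}: {m_tiles} query tile(s), {n_tiles} key/value iteration(s)")
```

Only seq_len=129 needs a second key/value iteration (and a second query tile), which fits the hypothesis that the failure requires more than one pass through the loop.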

Running the repro (saved as cutlass_repro.py) under compute-sanitizer --tool memcheck python cutlass_repro.py produces 64 errors of the form:

========= Invalid __global__ read of size 16 bytes
=========     at 0x2bb0 in void pytorch_flash::flash_fwd_kernel<pytorch_flash::Flash_fwd_kernel_traits<(int)32, (int)128, (int)128, (int)4, (bool)0, (bool)0, cutlass::bfloat16_t, pytorch_flash::Flash_kernel_traits<(int)32, (int)128, (int)128, (int)4, cutlass::bfloat16_t>>, (bool)0, (bool)0, (bool)0, (bool)0, (bool)0>(pytorch_flash::Flash_fwd_params)
=========     by thread (28,0,0) in block (1,0,0)
=========     Address 0x7f6f25e00a70 is out of bounds
=========     and is 8,336,181,105 bytes after the nearest allocation at 0x7f6d35000100 of size 512 bytes
=========     Saved host backtrace up to driver entry point at kernel launch time
=========     Host Frame: [0x304fd2]
=========                in /lib64/libcuda.so.1
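The distance in the report can be checked from the two addresses above. The reported figure equals the faulting address minus the last byte of the 512-byte allocation, and its size (about 7.8 GiB) suggests a badly wrong pointer rather than a small overrun at the edge of a tile.

```python
# Values copied from the compute-sanitizer report above.
bad_address = 0x7F6F25E00A70
alloc_base = 0x7F6D35000100
alloc_size = 512

# Distance measured from the last byte of the allocation (base + size - 1).
last_byte = alloc_base + alloc_size - 1
distance = bad_address - last_byte
print(f"{distance:,} bytes past the allocation ({distance / 2**30:.2f} GiB)")
```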

Full Repro Steps

FlashAttention requires an A100 or newer GPU to run (I have validated on both A100 and H100).

Follow the setup instructions for building PyTorch from source here: https://github.com/pytorch/pytorch#from-source

Before building from source, check out the PR above. You can use GitHub's CLI tool for this. Install gh:

conda install gh --channel conda-forge

Then check out the PR:

gh pr checkout 108070

For much faster builds, you can use these environment variables to turn off parts of the build that don't matter for this repro (source the script so that the exports apply to your current shell):

#!/bin/bash

export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
export USE_DISTRIBUTED=1
export BUILD_TEST=0
export BUILD_CAFFE2_OPS=0
export USE_FBGEMM=0
export BUILD_CAFFE2=0
export USE_OPENCV=0
export USE_QNNPACK=0
export USE_XNNPACK=0
export DEBUG=0
export USE_KINETO=1
export USE_CUDA=1
export TORCH_SHOW_CPP_STACKTRACES=1
export USE_GOLD_LINKER=1
export USE_NCCL=0
export WERROR=1
# export TORCH_CUDA_ARCH_LIST=9.0
kadeng commented 11 months ago

As an additional piece of information: the issue seems to be caused by some difference between CUTLASS 3.2.0 and 3.2.1, since the problem is not present in v3.2.0.

IonThruster commented 11 months ago

I tried to reproduce this issue locally, but unfortunately I haven't been able to so far:

(pytorch_conda) ~/pytorch/ima_bug $ cat ima_bug.py
import torch
from torch.nn.functional import scaled_dot_product_attention 
print(torch.__version__)
# Changing seq_len from 129 to 128 makes the IMA go away, likely because the kernel then does only 1 iteration
query = torch.randn(1, 1, 129, 8, device="cuda", dtype=torch.bfloat16) # (batch, num_heads, seq_len, embed_dim)
key = torch.randn(1, 1, 129, 8, device="cuda", dtype=torch.bfloat16) # (batch, num_heads, seq_len, embed_dim)
value = torch.randn(1, 1, 129, 8, device="cuda", dtype=torch.bfloat16) # (batch, num_heads, seq_len, embed_dim)
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    scaled_dot_product_attention(query, key, value)

(pytorch_conda) ~/pytorch/ima_bug $ compute-sanitizer --tool memcheck python3 ima_bug.py
========= COMPUTE-SANITIZER
2.2.0a0+gitc8610e6
========= ERROR SUMMARY: 0 errors

As far as I can tell, I do have your latest commit too, so I'm not sure what's different in my setup.
@drisspg, could you confirm whether this is still an issue? (I used CUDA Toolkit 12.2.2.)

kadeng commented 11 months ago

Assuming you're building PyTorch from source, did you do something like this?

cd pytorch/third_party/cutlass
git checkout v3.2.1
cd ../..
python setup.py develop

IonThruster commented 11 months ago

I ran the following before running setup.py:

gh pr checkout 108070

As mentioned above, the torch version shows up as:

2.2.0a0+gitc8610e6

This looks like a commit hash from the PR. I didn't explicitly check out a CUTLASS version.

kadeng commented 11 months ago

CUTLASS is pulled in as a git submodule under third_party/cutlass. I am not sure whether the gh tool updates submodules (a plain git checkout does not).

So the safest bet would be to follow the steps I listed above and make sure CUTLASS is definitely at v3.2.1.

By the way, I encountered the issue on CUDA 12.0, not 12.2.
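One way to confirm which CUTLASS commit a PyTorch checkout actually builds against is git submodule status, which prefixes each line with a space when the submodule matches the commit recorded by the superproject, + when it does not, - when it is uninitialized, and U on merge conflicts. The sketch below parses that output in Python; parse_submodule_status and cutlass_status are hypothetical helpers, and cutlass_status assumes it is run from the root of a PyTorch checkout.

```python
import subprocess
from typing import NamedTuple, Optional


class SubmoduleStatus(NamedTuple):
    state: str               # "ok", "uninitialized", "mismatch" or "conflict"
    sha: str                 # commit SHA-1 shown by git
    path: str                # submodule path, e.g. third_party/cutlass
    describe: Optional[str]  # `git describe` output such as "v3.2.1", if shown


# Prefix characters documented for `git submodule status`.
_STATES = {" ": "ok", "-": "uninitialized", "+": "mismatch", "U": "conflict"}


def parse_submodule_status(line: str) -> SubmoduleStatus:
    """Parse one line of `git submodule status` output."""
    state = _STATES[line[0]]
    parts = line[1:].split(maxsplit=2)
    sha, path = parts[0], parts[1]
    describe = parts[2].strip("()") if len(parts) > 2 else None
    return SubmoduleStatus(state, sha, path, describe)


def cutlass_status(repo_root: str = ".") -> SubmoduleStatus:
    """Run `git submodule status` for third_party/cutlass in a PyTorch checkout."""
    out = subprocess.run(
        ["git", "submodule", "status", "third_party/cutlass"],
        cwd=repo_root, capture_output=True, text=True, check=True,
    ).stdout
    # Strip only the trailing newline: the leading character is the state flag.
    return parse_submodule_status(out.rstrip("\n"))
```

A mismatch state means the files in third_party/cutlass are not at the commit the checked-out PyTorch branch records, which would be consistent with what gh pr checkout left behind in this thread; git submodule update --init --recursive brings the submodule back in sync.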

drisspg commented 11 months ago

I was running with cuda-toolkit 12.1. I will try to reproduce with 12.2 today. Also, I suggested gh pr checkout because it needs the fewest commands, but the only difference between that branch and main is that I updated the CUTLASS submodule to 3.2.1.

drisspg commented 11 months ago

I just updated to cuda-toolkit 12.2 and can still reproduce the IMA with compute-sanitizer. For more information, my environment is:

Collecting environment information...
PyTorch version: 2.2.0a0+git178268d
Is debug build: False
CUDA used to build PyTorch: 12.2
ROCM used to build PyTorch: N/A

OS: CentOS Stream 9 (x86_64)
GCC version: (GCC) 11.4.1 20230605 (Red Hat 11.4.1-2)
Clang version: 16.0.6 (Red Hat 16.0.6-1.el9)
CMake version: version 3.26.4
Libc version: glibc-2.34

Python version: 3.10.12 (main, Jul  5 2023, 18:54:27) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.19.0-0_fbk9_zion_11322_gb0aa76a79d7d-x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: 12.2.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H100
GPU 1: NVIDIA H100
GPU 2: NVIDIA H100
GPU 3: NVIDIA H100
GPU 4: NVIDIA H100
GPU 5: NVIDIA H100
GPU 6: NVIDIA H100
GPU 7: NVIDIA H100

Nvidia driver version: 525.105.17
cuDNN version: Probably one of the following:
/usr/lib64/libcudnn.so.8.8.0
/usr/lib64/libcudnn_adv_infer.so.8.8.0
/usr/lib64/libcudnn_adv_train.so.8.8.0
/usr/lib64/libcudnn_cnn_infer.so.8.8.0
/usr/lib64/libcudnn_cnn_train.so.8.8.0
/usr/lib64/libcudnn_ops_infer.so.8.8.0
/usr/lib64/libcudnn_ops_train.so.8.8.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: False

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   52 bits physical, 57 bits virtual
Byte Order:                      Little Endian
CPU(s):                          384
On-line CPU(s) list:             0-383
Vendor ID:                       AuthenticAMD
Model name:                      AMD EPYC 9654 96-Core Processor
CPU family:                      25
Model:                           17
Thread(s) per core:              2
Core(s) per socket:              96
Socket(s):                       2
Stepping:                        1
Frequency boost:                 enabled
CPU max MHz:                     2400.0000
CPU min MHz:                     1500.0000
BogoMIPS:                        4792.83
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid overflow_recov succor smca fsrm flush_l1d
Virtualization:                  AMD-V
L1d cache:                       6 MiB (192 instances)
L1i cache:                       6 MiB (192 instances)
L2 cache:                        192 MiB (192 instances)
L3 cache:                        768 MiB (24 instances)
NUMA node(s):                    2
NUMA node0 CPU(s):               0-95,192-287
NUMA node1 CPU(s):               96-191,288-383
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Vulnerable, IBPB: conditional, IBRS_FW, STIBP: always-on, RSB filling
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected

Versions of relevant libraries:
[pip3] flake8==6.0.0
[pip3] flake8-bugbear==23.3.23
[pip3] flake8-comprehensions==3.12.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==23.3.1
[pip3] flake8-simplify==0.19.3
[pip3] mypy==1.6.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.3
[pip3] optree==0.9.1
[pip3] torch==2.2.0a0+git178268d
[pip3] triton==2.1.0
[conda] numpy                     1.24.3                   pypi_0    pypi
[conda] optree                    0.9.1                    pypi_0    pypi
[conda] torch                     2.2.0a0+git178268d           dev_0    <develop>
[conda] triton                    2.1.0                    pypi_0    pypi
kadeng commented 11 months ago

In case it proves relevant: the system I use also has H100s.

drisspg commented 11 months ago

I have reproduced this on A100, and CI/CD is failing on A100 machines.

IonThruster commented 11 months ago

I can confirm that I can now reproduce the issue on my end as well; gh pr checkout wasn't bringing in the right version of CUTLASS earlier. We will keep investigating and keep you posted.

IonThruster commented 11 months ago

@drisspg / @kadeng: can you try replacing this line:

  DerivedType operator+(Index const& i) const { return {ptr_ + i / ElementsPerStoredItem}; }

with :

  DerivedType operator+(Index const& i) const { return {ptr_ + i}; }

Then see whether that fixes the issue on your side. We are redesigning this portion in an upcoming release anyway, so if this change works for you, I'll submit a PR for it.
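The thread does not say why dropping the division fixes the IMA. One plausible mechanism, offered only as a hypothesis: if ElementsPerStoredItem is a 32-bit unsigned constant and the kernel moves the pointer by a negative int offset (assuming FlashAttention-2's forward kernel steps from the last key/value tile back toward the first), C++'s usual arithmetic conversions make i / ElementsPerStoredItem an unsigned division. Then even with ElementsPerStoredItem equal to 1, as it would be for a 16-bit bfloat16_t, a small backward step wraps into a forward jump of about four billion elements. The Python below emulates that conversion; the tile size comes from the kernel traits in the log, while the row stride of 8 (one head of dimension 8) and the unsigned type are assumptions about this repro, not facts from the thread.

```python
# Hypothesis sketch, not confirmed in this thread: emulate C++ evaluating
# `i / ElementsPerStoredItem` with `i` an int and ElementsPerStoredItem a uint32_t.
UINT32_MOD = 2**32

BLOCK_N = 128                 # kBlockN from Flash_fwd_kernel_traits<32, 128, 128, 4, ...>
ROW_STRIDE = 8                # assumed key row stride (num_heads * head_dim) in the repro
BF16_BYTES = 2                # sizeof(cutlass::bfloat16_t)
ELEMENTS_PER_STORED_ITEM = 1  # assumed: one 16-bit element per stored item


def offset_old(i: int) -> int:
    """`ptr_ + i / ElementsPerStoredItem`: i is first converted to unsigned 32-bit."""
    return (i % UINT32_MOD) // ELEMENTS_PER_STORED_ITEM


def offset_new(i: int) -> int:
    """`ptr_ + i`: the signed offset is applied unchanged."""
    return i


step = -(BLOCK_N * ROW_STRIDE)  # assumed step: move back by one key/value tile
for name, fn in (("old", offset_old), ("new", offset_new)):
    elems = fn(step)
    print(f"{name}: {elems:+,} elements = {elems * BF16_BYTES:+,} bytes")
```

Under these assumptions the old expression jumps roughly 8.6 GB forward instead of 2 KB backward, which is comparable in size to the out-of-bounds distance in the sanitizer report.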

drisspg commented 11 months ago

@IonThruster Yup, this appears to have solved it locally for me!

IonThruster commented 11 months ago

Thanks for the update. The issue is fixed on the release/3.2.x branch.

@hwu36, could you please tag it when possible?

@drisspg / @kadeng, do you also have a list of recommended tests we can use to ensure better coverage in future releases?

drisspg commented 11 months ago

The test that caught this can be found here: https://github.com/pytorch/pytorch/blob/17b732eb0431be4b1b8df3d163338fb31b086d00/test/test_transformers.py#L2438

That being said, I would not expect CUTLASS to add this test to its own unit tests.

If we wanted to include FlashAttention tests, I think we would create a C++ harness that runs, with some representative inputs, the ops found here: https://github.com/pytorch/pytorch/blob/17b732eb0431be4b1b8df3d163338fb31b086d00/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp#L302
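For the representative inputs such a harness would need, this bug suggests covering sequence lengths on both sides of a tile boundary, since 128 passed and 129 failed. A minimal sketch in Python for brevity, assuming a tile size of 128 as in the kernel traits from the sanitizer log; boundary_seq_lens and representative_shapes are hypothetical helpers, the head dims are illustrative choices, and the shapes use the repro's (batch, num_heads, seq_len, embed_dim) layout.

```python
from itertools import product
from typing import List, Sequence, Tuple


def boundary_seq_lens(block: int = 128, max_tiles: int = 3) -> List[int]:
    """Sequence lengths just below, at, and just above each multiple of `block`."""
    lens = set()
    for k in range(1, max_tiles + 1):
        edge = k * block
        lens.update((edge - 1, edge, edge + 1))
    return sorted(lens)


def representative_shapes(
    head_dims: Sequence[int] = (8, 32, 64, 128), block: int = 128
) -> List[Tuple[int, int, int, int]]:
    """(batch, num_heads, seq_len, embed_dim) shapes covering tile edges and head dims."""
    return [(1, 1, s, d) for s, d in product(boundary_seq_lens(block), head_dims)]
```

Each shape could then be run through the flash backend under compute-sanitizer and compared against the math backend, as in the repro above.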

IonThruster commented 10 months ago

@drisspg : v3.2.2 has been tagged.

github-actions[bot] commented 9 months ago

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.