facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/
Other
8.34k stars, 585 forks

Sparse Implementation is not triggered #1054

Open ryanliu30 opened 3 months ago

ryanliu30 commented 3 months ago

🐛 Bug

The sparse implementation is never triggered by ScaledDotProduct, even when a very sparse boolean mask is passed.

To Reproduce

Steps to reproduce the behavior: following the documentation, I tested this script (test.py):

import torch
import argparse
from xformers.components.attention import ScaledDotProduct

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fraction", type=float, required=True)
    args = parser.parse_args()

    attention = ScaledDotProduct().cuda()

    # Forward a random batch of data
    inputs = torch.rand((16, 1024, 1024), device=torch.device("cuda"))

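    # Reset the peak-memory counter before building the mask and running attention
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    # Random boolean mask; --fraction is the share of True (kept) entries
    mask = (torch.rand((1024, 1024)) < args.fraction).cuda()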
    att = attention(q=inputs, k=inputs, v=inputs, att_mask=mask)

    torch.cuda.synchronize()
    max_memory = torch.cuda.max_memory_allocated() // 2 ** 20
    print(f"Peak memory use: {max_memory}MB")

and launch

python test.py --fraction 0.9

the output is

Peak memory use: 513MB

python test.py --fraction 0.1

the output is

Peak memory use: 513MB

The two commands were run in separate processes, and the reported numbers are identical. The improvement reported in the documentation seems to come from Python/CUDA internals rather than from the sparse implementation.
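A rough way to see why both runs can report the same number: on a dense path, the (batch, seq, seq) score tensor has the same size whatever fraction of the mask is True, whereas a CSR-style layout would scale with the number of kept entries. The sketch below is only a back-of-the-envelope estimate under stated assumptions (float32 scores, int32 indices, one sparsity pattern shared across the batch, softmax/matmul intermediates ignored). It is not how xformers' SparseCS accounts for memory and it does not try to reproduce the 513MB peak; the function names are hypothetical.

```python
# Back-of-the-envelope estimate. Assumptions: float32 scores, int32 indices,
# one (seq, seq) sparsity pattern shared by the whole batch, and no
# intermediate buffers or allocator overhead counted.
BYTES_F32 = 4
BYTES_I32 = 4


def dense_scores_bytes(batch: int, seq: int) -> int:
    """Size of a dense (batch, seq, seq) float32 attention-score tensor."""
    return batch * seq * seq * BYTES_F32


def csr_scores_bytes(batch: int, seq: int, density: float) -> int:
    """Approximate size if only the kept entries of the mask are stored.

    Score values are stored per batch element; the row offsets and column
    indices of the pattern are stored once and shared across the batch.
    """
    nnz = round(density * seq * seq)
    values = batch * nnz * BYTES_F32
    col_indices = nnz * BYTES_I32
    row_offsets = (seq + 1) * BYTES_I32
    return values + col_indices + row_offsets


if __name__ == "__main__":
    MiB = 2 ** 20
    batch, seq = 16, 1024  # shapes used in test.py
    dense = dense_scores_bytes(batch, seq)
    print(f"dense: {dense / MiB:.1f} MiB (independent of --fraction)")
    for fraction in (0.9, 0.1):
        sparse = csr_scores_bytes(batch, seq, fraction)
        print(f"CSR estimate, fraction={fraction}: {sparse / MiB:.1f} MiB")
```

For the shapes in test.py the dense score tensor alone is 64 MiB regardless of the mask. At density 1.0 the CSR estimate is larger than the dense tensor because of the index overhead, so a sparse layout only pays off below some density.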

Expected behavior

The run with --fraction 0.1 (a sparse mask, roughly 10% of positions kept) should report lower peak memory than the run with --fraction 0.9.

Environment

Collecting environment information...
PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.2 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04.1) 11.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-78-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA RTX A6000
GPU 1: NVIDIA RTX A6000
GPU 2: NVIDIA RTX A6000
GPU 3: NVIDIA RTX A6000
GPU 4: NVIDIA RTX A6000
GPU 5: NVIDIA RTX A6000
GPU 6: NVIDIA RTX A6000
GPU 7: NVIDIA RTX A6000
GPU 8: NVIDIA RTX A6000
GPU 9: NVIDIA RTX A6000

Nvidia driver version: 530.30.02
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 104
On-line CPU(s) list: 0-103
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Gold 5320 CPU @ 2.20GHz
CPU family: 6
Model: 106
Thread(s) per core: 2
Core(s) per socket: 26
Socket(s): 2
Stepping: 6
CPU max MHz: 3400.0000
CPU min MHz: 800.0000
BogoMIPS: 4400.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect wbnoinvd dtherm ida arat pln pts avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid fsrm md_clear pconfig flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 2.4 MiB (52 instances)
L1i cache: 1.6 MiB (52 instances)
L2 cache: 65 MiB (52 instances)
L3 cache: 78 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-25,52-77
NUMA node1 CPU(s): 26-51,78-103
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] torch==2.3.0
[pip3] torch_cluster==1.6.3+pt23cu121
[pip3] torch-ema==0.3
[pip3] torch_geometric==2.5.3
[pip3] torch-runstats==0.2.0
[pip3] torch_scatter==2.1.2+pt23cu121
[pip3] torch_sparse==0.6.18+pt23cu121
[pip3] torch_spline_conv==1.2.2+pt23cu121
[pip3] torchaudio==2.3.0
[pip3] torchvision==0.18.0
[pip3] triton==2.3.0
[conda] numpy 1.26.4 pypi_0 pypi
[conda] torch 2.3.0 pypi_0 pypi
[conda] torch-cluster 1.6.3+pt23cu121 pypi_0 pypi
[conda] torch-ema 0.3 pypi_0 pypi
[conda] torch-geometric 2.5.3 pypi_0 pypi
[conda] torch-runstats 0.2.0 pypi_0 pypi
[conda] torch-scatter 2.1.2+pt23cu121 pypi_0 pypi
[conda] torch-sparse 0.6.18+pt23cu121 pypi_0 pypi
[conda] torch-spline-conv 1.2.2+pt23cu121 pypi_0 pypi
[conda] torchaudio 2.3.0 pypi_0 pypi
[conda] torchvision 0.18.0 pypi_0 pypi
[conda] triton 2.3.0 pypi_0 pypi

Additional context

I checked the relevant code, and a straightforward fix seems to be adding an extra branch in ScaledDotProduct.forward so that it constructs a SparseCS mask instead of an AttentionMask when a sufficiently sparse boolean mask is passed. This fix does not support additive (float) masks, so it might not be ideal:

...
        # Convenience, create an attention mask if a tensor was passed
        if att_mask is not None and isinstance(att_mask, torch.Tensor):
            # By default we don't know of the causality, and a check would be expensive
            # Boolean mask with fewer than 30% of entries kept: use the sparse SparseCS path
            if att_mask.dtype == torch.bool and att_mask.float().mean() < 0.3:
                att_mask = SparseCS(att_mask, device=att_mask.device)
            else:
                att_mask = (
                    AttentionMask.from_bool(att_mask)
                    if att_mask.dtype == torch.bool
                    else AttentionMask(att_mask, is_causal=False)
                )

This gives

python test.py --fraction 0.9

the output is

Peak memory use: 513MB

python test.py --fraction 0.1

the output is

Peak memory use: 143MB
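The patch above routes a boolean mask to SparseCS only when its density (att_mask.float().mean()) is below 0.3. Below is a pure-Python sketch of that decision and of the CSR layout (row offsets plus column indices) that a sparse path stores instead of a dense mask. The helper names density, bool_mask_to_csr and choose_mask_path are hypothetical, plain lists stand in for tensors, and this is not xformers' SparseCS code; it assumes True marks an entry to keep, so that --fraction in test.py is the density.

```python
from typing import List, Tuple

Mask = List[List[bool]]  # True = keep (compute) this attention entry


def density(mask: Mask) -> float:
    """Fraction of True entries, the quantity att_mask.float().mean() computes."""
    total = sum(len(row) for row in mask)
    kept = sum(sum(row) for row in mask)  # bools count as 0/1
    return kept / total if total else 0.0


def bool_mask_to_csr(mask: Mask) -> Tuple[List[int], List[int]]:
    """Return (row_offsets, col_indices) of the True entries in CSR order.

    Row i's kept columns are col_indices[row_offsets[i]:row_offsets[i + 1]].
    """
    row_offsets = [0]
    col_indices: List[int] = []
    for row in mask:
        col_indices.extend(j for j, keep in enumerate(row) if keep)
        row_offsets.append(len(col_indices))
    return row_offsets, col_indices


def choose_mask_path(mask: Mask, threshold: float = 0.3) -> str:
    """Hypothetical dispatch mirroring the patch: sparse path below the threshold."""
    return "sparse" if density(mask) < threshold else "dense"


if __name__ == "__main__":
    mask = [
        [True, False, False, False],
        [False, True, False, False],
        [True, False, False, True],
        [False, False, False, False],
    ]
    print(density(mask), choose_mask_path(mask), bool_mask_to_csr(mask))
```

In the usage example the mask has density 0.25, so choose_mask_path returns "sparse", and bool_mask_to_csr returns row offsets [0, 1, 2, 4, 4] and column indices [0, 1, 0, 3]. The threshold of 0.3 is the heuristic from the patch, not a tuned value.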

Further investigation is needed to make this work with float (additive) masks.
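On the float-mask question: assuming the usual additive convention, a float mask is added to the attention scores before the softmax, so a boolean pattern corresponds to adding 0.0 for kept entries and -inf for masked ones. A float mask can therefore be reduced to a sparsity pattern only when every entry is one of those two values; any other value is a real bias that a bare pattern would drop. A minimal sketch of that check, with hypothetical helper names and plain lists standing in for tensors:

```python
import math
from typing import List, Optional


def bool_to_additive(mask: List[List[bool]]) -> List[List[float]]:
    """Kept entries add 0.0 to the score; masked entries add -inf (softmax weight 0)."""
    return [[0.0 if keep else -math.inf for keep in row] for row in mask]


def additive_to_bool(mask: List[List[float]]) -> Optional[List[List[bool]]]:
    """Recover a boolean pattern, or None if the mask carries real biases.

    Only entries equal to 0.0 (keep) or -inf (drop) map to a pattern; any
    other value, e.g. a position bias or NaN, makes the mask non-reducible.
    """
    pattern: List[List[bool]] = []
    for row in mask:
        out_row: List[bool] = []
        for value in row:
            if value == 0.0:
                out_row.append(True)
            elif value == -math.inf:
                out_row.append(False)
            else:
                return None
        pattern.append(out_row)
    return pattern
```

A float mask that passes this check could take the same sparse route as a boolean one; handling genuine biases would need the finite values stored alongside the CSR indices, which the sketch does not attempt.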

WayenVan commented 1 month ago

Yes, I just debugged at that spot: this is a terrible bug, and the behavior differs from what the documentation says.

WayenVan commented 1 month ago

This really should be fixed ASAP.

danthe3rd commented 1 month ago

Hi, we don't plan to support ScaledDotProduct further (and will actually deprecate it). It was written before Flash-Attention existed, and it is no longer competitive, except at absurdly/unusably high levels of sparsity.