facebookresearch / xformers

Incorrect attention output with SparseCS mask #1124

Open francois-rozet opened 1 month ago

francois-rozet commented 1 month ago

🐛 Bug

The output of scaled_dot_product_attention is wrong when the mask is a SparseCS matrix. In particular, the output at the last position of the sequence is incorrect, while the outputs at all other positions are correct.

To Reproduce

import torch
import xformers.components.attention.core as xf

B, M, N, C, D = 2, 1024, 768, 256, 384

torch.manual_seed(0)

q = torch.randn(B, M, C).cuda()
k = torch.randn(B, N, C).cuda()
v = torch.randn(B, N, D).cuda()

mask = torch.rand((M, N)).cuda() < 0.01
sparse = xf.SparseCS(mask, device=mask.device)

y_torch = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
y_sparse = xf.scaled_dot_product_attention(q, k, v, att_mask=sparse)

error = (y_torch - y_sparse).abs()
print(error[:, :-1].max())  # tensor(1.7881e-06, device='cuda:0')
print(error[:, -1].max())   # tensor(1.6439, device='cuda:0')

Expected behavior

The output of torch.nn.functional.scaled_dot_product_attention and xf.scaled_dot_product_attention should be the same (up to some tolerance).
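To make "the same up to some tolerance" concrete, here is a minimal sketch of a helper that lists the sequence positions where two attention outputs disagree. The name mismatched_positions and the default tolerances are assumptions for illustration, not part of xformers or of the original report.

```python
import torch

def mismatched_positions(y_ref, y_test, rtol=1e-4, atol=1e-5):
    """Return the sequence positions (dim 1) where any element differs beyond tolerance."""
    close = torch.isclose(y_ref, y_test, rtol=rtol, atol=atol)  # (B, M, D)
    bad = ~close.all(dim=-1).all(dim=0)                         # (M,)
    return bad.nonzero().flatten()

# Synthetic stand-ins for y_torch and y_sparse: identical except at the last position.
torch.manual_seed(0)
y_ref = torch.randn(2, 8, 4)
y_test = y_ref.clone()
y_test[:, -1] += 1.0

print(mismatched_positions(y_ref, y_test).tolist())  # [7]
```

Applied to y_torch and y_sparse from the reproduction, an empty result would correspond to the expected behavior.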

Environment

Collecting environment information...
PyTorch version: 2.4.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Rocky Linux release 8.10 (Green Obsidian) (x86_64)
GCC version: (GCC) 8.5.0 20210514 (Red Hat 8.5.0-22)
Clang version: Could not collect
CMake version: version 3.26.5
Libc version: glibc-2.28

Python version: 3.10.13 (main, May  1 2024, 20:55:07) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.1.97.1.fi-x86_64-with-glibc2.28
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla V100-SXM2-32GB
Nvidia driver version: 550.90.07
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:        x86_64
CPU op-mode(s):      32-bit, 64-bit
Byte Order:          Little Endian
CPU(s):              40
On-line CPU(s) list: 0-39
Thread(s) per core:  1
Core(s) per socket:  20
Socket(s):           2
NUMA node(s):        2
Vendor ID:           GenuineIntel
CPU family:          6
Model:               85
Model name:          Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz
Stepping:            4
CPU MHz:             3603.735
CPU max MHz:         3700.0000
CPU min MHz:         1000.0000
BogoMIPS:            4800.00
L1d cache:           32K
L1i cache:           32K
L2 cache:            1024K
L3 cache:            28160K
NUMA node0 CPU(s):   0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36,38
NUMA node1 CPU(s):   1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31,33,35,37,39
Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single pti intel_ppin ssbd mba ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req pku ospke md_clear flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.26.3
[pip3] torch==2.4.1+cu121
[pip3] torchvision==0.19.1+cu121
[pip3] triton==3.0.0
[conda] Could not collect

francois-rozet commented 1 month ago

After more troubleshooting, it seems that the conversion to a sparse matrix is incorrect.

>>> mask = torch.rand((M, N)).cuda() < 0.01
>>> sparse = xf.SparseCS(mask, device=mask.device)
>>> (sparse.to_dense() != mask).nonzero()
tensor([[   0, 1023,  418],
        [   0, 1023,  573],
        [   0, 1023,  583]], device='cuda:0')

The issue comes from _round_nnz, which drops non-zero elements in the mask when the number of non-zero elements is not a multiple of 4.

https://github.com/facebookresearch/xformers/blob/4a9dd7ec079e0c935db10daa2a1a89fd19cfa231/xformers/sparse/utils.py#L116-L117
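For illustration, the truncation described above can be reproduced in plain PyTorch on a small mask. This is a sketch of the behavior as I understand it, not a verbatim copy of the library code; truncate_nnz is a hypothetical name.

```python
import torch

def truncate_nnz(mask: torch.Tensor, divisible_by: int = 4) -> torch.Tensor:
    """Hypothetical stand-in: keep only the first k non-zeros, k a multiple of divisible_by."""
    nonzero = torch.where(mask)       # index tensors, in row-major order
    nnz = nonzero[0].shape[0]
    keep = nnz - nnz % divisible_by   # largest multiple of divisible_by that is <= nnz
    kept = tuple(idx[:keep] for idx in nonzero)
    out = torch.zeros_like(mask)
    out[kept] = True
    return out

# A 4x4 mask with 6 non-zero entries (6 is not a multiple of 4).
mask = torch.zeros(1, 4, 4, dtype=torch.bool)
for r, c in [(0, 1), (1, 0), (1, 3), (2, 2), (3, 0), (3, 2)]:
    mask[0, r, c] = True

rounded = truncate_nnz(mask)
dropped = (mask & ~rounded).nonzero()
print(dropped.tolist())  # [[0, 3, 0], [0, 3, 2]]
```

The dropped entries are the last ones in row-major order, which is consistent with the errors showing up only in the last row of the mask.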

Modifying _round_nnz so that it pads the sparsity pattern with a few zero elements (with value False), instead of dropping non-zero elements, resolves the discrepancy between sparse and mask. Note that the following implementation does not require CPU-GPU synchronization.

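import torch
import xformers.sparse.utils

def monkey_round_nnz(mask, divisible_by=4):
    # Number of non-zero entries, kept as a tensor on the device (no .item() call).
    nnz = torch.count_nonzero(mask)
    # Running count of zero entries in row-major order.
    cunz = torch.cumsum(~mask.flatten(), dim=0)
    # True up to and including the first (-nnz) % divisible_by zero entries.
    flip = cunz <= (-nnz) % divisible_by

    return torch.logical_or(flip.reshape_as(mask), mask)

xformers.sparse.utils._round_nnz = monkey_round_nnz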
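A quick sanity check of the padding logic in plain PyTorch, on the CPU and without xformers. The function body is repeated under the hypothetical name pad_nnz so that the snippet is self-contained; the mask shape and density are arbitrary choices.

```python
import torch

def pad_nnz(mask: torch.Tensor, divisible_by: int = 4) -> torch.Tensor:
    # Same logic as the patched function above, repeated for self-containment.
    nnz = torch.count_nonzero(mask)
    cunz = torch.cumsum(~mask.flatten(), dim=0)
    flip = cunz <= (-nnz) % divisible_by
    return torch.logical_or(flip.reshape_as(mask), mask)

torch.manual_seed(0)
checked = 0
for _ in range(100):
    mask = torch.rand(1, 16, 12) < 0.1
    padded = pad_nnz(mask)
    nnz = int(mask.sum())
    # The padded pattern has a multiple of 4 entries...
    assert int(padded.sum()) % 4 == 0
    # ...adds exactly the missing number of entries...
    assert int(padded.sum()) - nnz == (-nnz) % 4
    # ...and never drops an entry of the original mask.
    assert not bool((mask & ~padded).any())
    checked += 1

print(checked)  # 100
```

Note that the padding needs at least (-nnz) % divisible_by zero entries in the mask, which a sparse mask always has in practice.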

However, SparseCSRTensor does not take the values of the mask into account to perform a masked matmul, which results in incorrect attention values.

https://github.com/facebookresearch/xformers/blob/4a9dd7ec079e0c935db10daa2a1a89fd19cfa231/xformers/sparse/csr_tensor.py#L177

Taking the values of mask in _masked_matmul into account solves the issue.

@classmethod
def _masked_matmul(cls, a, b, mask):
    if not (type(a) is torch.Tensor and type(b) is torch.Tensor):
        return NotImplemented
    assert mask.shape[1] == a.shape[1]
    assert mask.shape[2] == b.shape[2]
    values = mask.__values
    row_indices = mask.__row_indices
    row_offsets = mask.__row_offsets
    column_indices = mask.__column_indices
    tansp_info = mask.__transp_info
    out = _csr_ops._sddmm.apply(
        a.contiguous(),
        b.transpose(-2, -1).contiguous(),
        row_indices,
        row_offsets,
        column_indices,
        tansp_info,
    )
    # Entries stored in the sparsity pattern whose mask value is False must not be attended to.
    out = torch.where(values, out, float("-inf"))
    return cls._wrap(
        mask.shape,
        out,
        row_indices,
        row_offsets,
        column_indices,
        tansp_info,
    )
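To show why the values matter once the pattern is padded, here is a dense emulation in plain PyTorch. It does not call the xformers kernels; pattern stands in for the stored sparsity pattern and mask for the stored values, and the sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
M, N, C = 4, 6, 8
q, k, v = torch.randn(M, C), torch.randn(N, C), torch.randn(N, C)

mask = torch.zeros(M, N, dtype=torch.bool)
mask[:, 0] = True   # every query may attend to key 0
mask[3, 4] = True   # the last query may also attend to key 4

pattern = mask.clone()
pattern[0, 5] = True  # padded entry: stored in the pattern, but its value is False

logits = (q @ k.T) / C**0.5
neg_inf = float("-inf")

# Pattern only: the padded entry leaks into the softmax of row 0.
y_pattern = torch.softmax(logits.masked_fill(~pattern, neg_inf), dim=-1) @ v
# Pattern and values: stored entries whose value is False are also set to -inf.
y_values = torch.softmax(logits.masked_fill(~(pattern & mask), neg_inf), dim=-1) @ v

# Dense reference using the original mask.
y_ref = torch.nn.functional.scaled_dot_product_attention(
    q[None], k[None], v[None], attn_mask=mask[None]
)[0]

print(torch.allclose(y_values, y_ref, atol=1e-5))           # True
print(torch.allclose(y_pattern[0], y_ref[0], atol=1e-5))    # False
```

Only the row that contains the padded entry is affected, which is the role played by the torch.where call in the patched _masked_matmul.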