facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

memory_efficient_attention fw produce inconsistent results #1039

Open ShijunK opened 6 months ago

ShijunK commented 6 months ago

❓ Questions and Help

memory_efficient_attention fw produce inconsistent results

I am not sure what is going on. Is it an incorrect build, or a problem with specific version combinations?

The test fails for some combinations:

| xformers | torch | CUDA | GPU | CUDA compute capability | Status |
| --- | --- | --- | --- | --- | --- |
| v0.0.20+1dc3d7a (built from source) | 1.13 | 11.7 | Quadro RTX 6000 | 7.5 | Failed |
| v0.0.20+1dc3d7a (built from source) | 1.13 | 11.7 | A100 | 8.0 | Failed |
| v0.0.21+320b5ad (built from source) | 1.13 | 11.7 | Quadro RTX 6000 | 7.5 | Failed |
| v0.0.22+1e065bc (built from source) | 1.13 | 11.7 | Quadro RTX 6000 | 7.5 | Failed |
| v0.0.23+1254a16 (built from source) | 1.13 | 11.7 | Quadro RTX 6000 | 7.5 | Failed |

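but passes for others:

| xformers | torch | CUDA | GPU | CUDA compute capability | Status |
| --- | --- | --- | --- | --- | --- |
| v0.0.20+1dc3d7a (built from source) | 1.13 | 11.7 | RTX A6000 | 8.6 | Passed |
| v0.0.24+f7e46d5 (built from source) | 2.2 | 11.8 | A100 | 8.0 | Passed |
| v0.0.24+f7e46d5 (built from source) | 2.2 | 11.8 | RTX A6000 | 8.6 | Passed |
| v0.0.24+f7e46d5 (built from source) | 2.2 | 12.1 | RTX A6000 | 8.6 | Passed |
| v0.0.24+f7e46d5 (built from source) | 2.2 | 12.1 | H100 | 9.0 | Passed |
| 0.0.22.post7 (pip install) | 2.1 | 11.8 | A100 | 8.0 | Passed |
| 0.0.23 (pip install) | 2.1.1 | 11.8 | A100 | 8.0 | Passed |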

Command

pytest test_simple.py -v

To Reproduce

Steps to reproduce the behavior (for the combination v0.0.20+1dc3d7a built from source, torch 1.13, CUDA 11.7, A100):

  1. git checkout v0.0.20
  2. git submodule update --init --recursive
  3. Install cuda-11.7 locally
  4. Install torch-1.13+cu117 in a venv
  5. python setup.py bdist_wheel
  6. Install the built xformers wheel in the venv
  7. pytest test_simple.py -v

test code:

import sys
import pytest
import torch

import xformers.ops
from xformers import info

@pytest.mark.parametrize("batch_size", [(1), (4), (8)])
@pytest.mark.parametrize(
    "seq_len",
    [
        (2**1),
        (2**3),
        (2**6),
        (2**9),
    ],
)
@pytest.mark.parametrize(
    "k_seq_len",
    [
        (2**1),
        (2**3),
        (2**6),
        (2**9),  # 512
    ],
)
@pytest.mark.parametrize("dim_model", [(128)])
@pytest.mark.parametrize(
    "dtype,rtol,atol",
    [
        (torch.float32, 2e-5, 3e-4),
        (torch.float16, 4e-4, 4e-3),
    ],
)
def test_mem_efficient(
    batch_size,
    seq_len,
    k_seq_len,
    dim_model,
    dtype,
    rtol,
    atol,
):
    dropout = 0.0

    device = torch.device("cuda")

    q = torch.randn(
        (batch_size, seq_len, dim_model),
        requires_grad=False,
        device=device,
        dtype=dtype,
    )

    k = v = torch.randn(
        (batch_size, k_seq_len, dim_model),
        requires_grad=False,
        device=device,
        dtype=dtype,
    )

    with torch.no_grad():
        result_a = xformers.ops.memory_efficient_attention(q, k, v, p=dropout, op=(
            xformers.ops.fmha.cutlass.FwOp,
            xformers.ops.fmha.cutlass.BwOp,
        ))

        result_b = xformers.ops.memory_efficient_attention(q, k, v, p=dropout, op=(
            xformers.ops.fmha.cutlass.FwOp,
            xformers.ops.fmha.cutlass.BwOp,
        ))
    is_close = torch.isclose(
        result_a,
        result_b,
        rtol=rtol,
        atol=atol,
    )
    assert torch.all(is_close)

if __name__ == "__main__":
    info.print_info()
    sys.exit(pytest.main(["--color=yes", "-s", "-vv", __file__]))
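The test above only asserts a boolean, so a failure does not say whether the two runs differ by rounding noise or by something much larger. A minimal sketch of a helper that reports the size of the mismatch follows; `summarize_mismatch` is a hypothetical name, not part of xformers or PyTorch, and the tensors in the usage lines are toy CPU values standing in for `result_a` and `result_b`.

```python
import torch


def summarize_mismatch(a, b, rtol, atol):
    """Describe how far apart two tensors are (hypothetical helper)."""
    # Compare in float32 so float16 inputs do not lose precision in the diff.
    a32, b32 = a.float(), b.float()
    close = torch.isclose(a32, b32, rtol=rtol, atol=atol)
    return {
        "max_abs_diff": (a32 - b32).abs().max().item(),
        "mismatched_fraction": 1.0 - close.float().mean().item(),
        "has_nan": bool(torch.isnan(a32).any() or torch.isnan(b32).any()),
    }


# Toy stand-ins for result_a / result_b: one element out of six differs.
result_a = torch.zeros(2, 3)
result_b = result_a.clone()
result_b[0, 0] = 1.0
stats = summarize_mismatch(result_a, result_b, rtol=2e-5, atol=3e-4)
print(stats)
```

Printing these numbers from inside the failing test would show whether the mismatch is a handful of elements near the tolerance or a large part of the output.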

output:

memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-2-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-2-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-2-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-8-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-8-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-8-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-64-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-64-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-64-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-512-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-512-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-512-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-2-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-2-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-2-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-8-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-8-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-8-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-64-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-64-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-64-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-512-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-512-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-8-512-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-2-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-2-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-2-8] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-8-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-8-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-8-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-64-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-64-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-64-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-512-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-512-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-64-512-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-2-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-2-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-2-8] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-8-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-8-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-8-8] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-64-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-64-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-64-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-512-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-512-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-512-512-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-2-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-2-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-2-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-8-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-8-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-8-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-64-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-64-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-64-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-512-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-512-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-2-512-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-2-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-2-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-2-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-8-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-8-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-8-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-64-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-64-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-64-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-512-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-512-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-512-8] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-2-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-2-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-2-8] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-8-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-8-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-8-8] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-64-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-64-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-64-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-512-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-512-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-64-512-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-2-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-2-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-2-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-8-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-8-4] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-8-8] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-64-1] PASSED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-64-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-64-8] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-512-1] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-512-4] FAILED
memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-512-512-8] FAILED
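To look for a pattern in the log above, the failures can be tabulated per parameter. The sketch below uses only the standard library. It assumes the test ID fields are ordered `dtype-rtol-atol-dim_model-k_seq_len-seq_len-batch_size`, which is how pytest usually orders stacked `parametrize` decorators (closest decorator first); that ordering is an assumption and should be checked against the test file. `parse_line` and `failure_counts` are hypothetical helper names.

```python
import re
from collections import Counter

LINE_RE = re.compile(r"\[(?P<id>[^\]]+)\]\s+(?P<status>PASSED|FAILED)")


def parse_line(line):
    """Turn one pytest -v result line into a dict, or None if it does not match."""
    match = LINE_RE.search(line)
    if match is None:
        return None
    # Split from the right: the rtol field ("2e-05") itself contains a hyphen.
    head, dim_model, k_seq_len, seq_len, batch_size = match.group("id").rsplit("-", 4)
    return {
        "dtype": head.split("-", 1)[0],
        "dim_model": int(dim_model),
        "k_seq_len": int(k_seq_len),  # assumed field position
        "seq_len": int(seq_len),      # assumed field position
        "batch_size": int(batch_size),
        "failed": match.group("status") == "FAILED",
    }


def failure_counts(lines, key):
    """Map each value of `key` to a (failed, total) pair."""
    failed, total = Counter(), Counter()
    for line in lines:
        record = parse_line(line)
        if record is None:
            continue
        total[record[key]] += 1
        failed[record[key]] += int(record["failed"])
    return {value: (failed[value], total[value]) for value in sorted(total)}


# Three lines copied from the log above; pass the full log to see the real counts.
sample = [
    "memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-2-1] PASSED",
    "memory_efficient_test.py::test_mem_efficient[dtype0-2e-05-0.0003-128-2-8-1] FAILED",
    "memory_efficient_test.py::test_mem_efficient[dtype1-0.0004-0.004-128-8-512-4] FAILED",
]
print(failure_counts(sample, "dtype"))
print(failure_counts(sample, "batch_size"))
```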

Expected behavior

xformers memory-efficient attention should produce forward results that agree within tolerance when the same call is executed twice on the same inputs. This should hold across torch versions (1.13 and 2.2), CUDA versions (11.7, 11.8, 12.1), GPUs with different compute capabilities (7.5, 8.0, 8.6, 9.0), and different q and k sequence lengths, batch sizes, and data types.

Environment

Please copy and paste the output from the environment collection script from PyTorch (or fill out the checklist below manually).

You can run the script with:

# For security purposes, please check the contents of collect_env.py before running it.
python -m torch.utils.collect_env

Additional context

One failing environment:

xFormers 0.0.20+1dc3d7a.d20240311
memory_efficient_attention.cutlassF:               available
memory_efficient_attention.cutlassB:               available
memory_efficient_attention.flshattF:               available
memory_efficient_attention.flshattB:               available
memory_efficient_attention.smallkF:                available
memory_efficient_attention.smallkB:                available
memory_efficient_attention.tritonflashattF:        unavailable
memory_efficient_attention.tritonflashattB:        unavailable
indexing.scaled_index_addF:                        available
indexing.scaled_index_addB:                        available
indexing.index_select:                             available
swiglu.dual_gemm_silu:                             available
swiglu.gemm_fused_operand_sum:                     available
swiglu.fused.p.cpp:                                available
is_triton_available:                               False
is_functorch_available:                            False
pytorch.version:                                   1.13.0+cu117
pytorch.cuda:                                      available
gpu.compute_capability:                            8.0
gpu.name:                                          A100-SXM-80GB
build.info:                                        available
build.cuda_version:                                1107
build.python_version:                              3.10.13
build.torch_version:                               1.13.0+cu117
build.env.TORCH_CUDA_ARCH_LIST:                    None
build.env.XFORMERS_BUILD_TYPE:                     None
build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS:        None
build.env.NVCC_FLAGS:                              None
build.env.XFORMERS_PACKAGE_FROM:                   None
build.nvcc_version:                                11.7.64
source.privacy:                                    open source

One passing environment:

xFormers 0.0.23+cu118
memory_efficient_attention.cutlassF:               available
memory_efficient_attention.cutlassB:               available
memory_efficient_attention.decoderF:               available
memory_efficient_attention.flshattF@v2.3.6:        available
memory_efficient_attention.flshattB@v2.3.6:        available
memory_efficient_attention.smallkF:                available
memory_efficient_attention.smallkB:                available
memory_efficient_attention.tritonflashattF:        unavailable
memory_efficient_attention.tritonflashattB:        unavailable
memory_efficient_attention.triton_splitKF:         available
indexing.scaled_index_addF:                        available
indexing.scaled_index_addB:                        available
indexing.index_select:                             available
swiglu.dual_gemm_silu:                             available
swiglu.gemm_fused_operand_sum:                     available
swiglu.fused.p.cpp:                                available
is_triton_available:                               True
pytorch.version:                                   2.1.1+cu118
pytorch.cuda:                                      available
gpu.compute_capability:                            8.0
gpu.name:                                          A100-SXM-80GB
build.info:                                        available
build.cuda_version:                                1108
build.python_version:                              3.10.13
build.torch_version:                               2.1.1+cu118
build.env.TORCH_CUDA_ARCH_LIST:                    5.0+PTX 6.0 6.1 7.0 7.5 8.0+PTX 9.0
build.env.XFORMERS_BUILD_TYPE:                     Release
build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS:        None
build.env.NVCC_FLAGS:                              None
build.env.XFORMERS_PACKAGE_FROM:                   wheel-v0.0.23
build.nvcc_version:                                11.8.89
source.privacy:                                    open source

All of the xformers versions above (v0.0.20+1dc3d7a, v0.0.21+320b5ad, v0.0.22+1e065bc, v0.0.23+1254a16, v0.0.24+f7e46d5) were built from source, except 0.0.22.post7 and 0.0.23, which were installed with pip.

danthe3rd commented 6 months ago

Hi, if you want deterministic (reproducible) results, you need to enable it in PyTorch: https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
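A minimal sketch of what enabling this looks like, using `torch.use_deterministic_algorithms` from the linked page. The `CUBLAS_WORKSPACE_CONFIG` setting is the one the PyTorch docs ask for with CUDA 10.2 or later. Whether the xformers CUTLASS kernels honor this flag is not established in this thread, so treat it as a first thing to try rather than a fix. The tensors here are small CPU placeholders.

```python
import os

# Required by some cuBLAS operations on CUDA >= 10.2 when determinism is on;
# set it before any CUDA work starts.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

import torch

# Ops without a deterministic implementation now raise a RuntimeError
# instead of silently producing run-to-run differences.
torch.use_deterministic_algorithms(True)

# Seeding makes the random inputs repeatable as well.
torch.manual_seed(0)
first = torch.randn(4, 8)
torch.manual_seed(0)
second = torch.randn(4, 8)

print(torch.are_deterministic_algorithms_enabled(), torch.equal(first, second))
```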

ShijunK commented 1 month ago

@danthe3rd, I forgot to update this. After further debugging, we found that the CUTLASS memory-efficient attention produces wrong results (not small differences due to floating-point rounding or randomness) on older GPUs (compute capability < 8.0) with xformers 0.0.20 and 0.0.24.
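One way to separate wrong results from rounding noise is to compare against a plain PyTorch attention computed in float32. The sketch below assumes the default xformers scaling of `1 / sqrt(head_dim)` and no mask or dropout, matching the call in the test; `reference_attention` is a hypothetical helper, and the CPU tensors only sanity-check the helper itself.

```python
import torch


def reference_attention(q, k, v):
    """softmax(q @ k^T / sqrt(d)) @ v, computed in float32 (hypothetical helper)."""
    q32, k32, v32 = q.float(), k.float(), v.float()
    scale = q32.shape[-1] ** -0.5  # assumed default scale
    weights = torch.softmax((q32 @ k32.transpose(-2, -1)) * scale, dim=-1)
    return (weights @ v32).to(q.dtype)


# CPU sanity check: with q = 0 every key gets equal weight,
# so each output row must equal the mean of v over the key axis.
q = torch.zeros(1, 2, 4)
k = torch.randn(1, 3, 4)
v = torch.randn(1, 3, 4)
out = reference_attention(q, k, v)
expected = v.mean(dim=1, keepdim=True).expand_as(out)
print(torch.allclose(out, expected, atol=1e-6))

# On a GPU, the same helper could be compared against
# xformers.ops.memory_efficient_attention(q, k, v) with the test's rtol/atol.
```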

As a temporary workaround, we restrict training and inference jobs to GPUs with compute capability >= 8.0.
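A guard for this kind of workaround might look like the following sketch. The `(8, 0)` threshold comes from the comment above; `capability_ok` and `current_gpu_ok` are hypothetical names, and `torch.cuda.get_device_capability` is the PyTorch call that returns the `(major, minor)` pair.

```python
from typing import Tuple

MIN_CAPABILITY = (8, 0)  # threshold taken from the workaround described above


def capability_ok(capability: Tuple[int, int],
                  minimum: Tuple[int, int] = MIN_CAPABILITY) -> bool:
    """True if a (major, minor) compute capability meets the minimum."""
    return tuple(capability) >= tuple(minimum)


def current_gpu_ok(device_index: int = 0) -> bool:
    """Check the local GPU; returns False when no CUDA device is visible."""
    import torch  # imported lazily so capability_ok works without torch

    if not torch.cuda.is_available():
        return False
    return capability_ok(torch.cuda.get_device_capability(device_index))


# The GPUs from the tables in this issue, by compute capability.
for name, cap in [("Quadro RTX 6000", (7, 5)), ("A100", (8, 0)),
                  ("RTX A6000", (8, 6)), ("H100", (9, 0))]:
    print(name, capability_ok(cap))
```

A job launcher could call `current_gpu_ok()` at startup and refuse to run, or fall back to a different attention implementation, when it returns False.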