Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Result mismatch with headdim=256 bwd #1306

Open zidanehuang001 opened 4 days ago

zidanehuang001 commented 4 days ago

Hello,

I'm trying to test head_dim=256 backward performance on H100. With the modifications below, I managed to make it run; however, the result comparison reports a mismatch. Modifications:

  1. add run_mha_bwd_hdim256 in hopper/flash_bwd_launch_template.h; the 64, 64 tile sizes follow https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/flash_bwd_launch_template.h#L301:
    template<typename T>
    void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
        constexpr static int Headdim = 256;
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            BOOL_SWITCH(params.is_local, Is_local, [&] {
                BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
                    BOOL_SWITCH(params.deterministic, Deterministic, [&] {
                        // Tile sizes kBlockM = 64, kBlockN = 64, taken from the FA2 hdim256 backward config linked above.
                        run_flash_bwd<Headdim, 64, 64, T, Is_causal, Is_local && !Is_causal, Varlen, Deterministic, false, false, 1, 2, 1>(params, stream);
                    });
                });
            });
        });
    }
  2. change head_size limit from 128 to 256 in hopper/flash_api.cpp
  3. uncomment "flash_bwd_hdim256_fp16_sm90.cu" in hopper/setup.py

When running hopper/benchmark_attn.py with batch_size = 1, seqlen = 8192, nheads = 36, I came across this error, which indicates a result mismatch:

### mode = 'bwd', batch_size = 1, headdim = 256, seqlen = 8192, causal = False ###
Traceback (most recent call last):
  File "/workspace/code/flash-attention/hopper/benchmark_attn.py", line 294, in <module>
    torch.testing.assert_close(ref_dv, dv, atol=0.05, rtol=0.05)
  File "/usr/local/lib/python3.10/dist-packages/torch/testing/_comparison.py", line 1530, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

PS: I noticed there is a TODO for headdim256 bwd in hopper/flash_api.cpp. Could this lead to the mismatch? Does anything in its numbers need tuning? It seems my modifications above shouldn't introduce such an error.

else if (head_size == 256) {
    // TODO for hdim 256
    if (num_n_blocks <= 40) {
        start_threshold = .24f;
    } else if (std::log2f(num_n_blocksf) <= 8) {
        start_threshold = .33f + std::max(0.f, (std::log2f(num_n_blocksf) - std::log2f(50)) * 0.02971f);
    } else {
        // Just split freely
        start_threshold = .8f;
    }
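
(For context, the assertion at benchmark_attn.py line 294 compares dv from the fused kernel against a reference ref_dv. Below is a minimal, self-contained sketch of that kind of check; it is not the actual benchmark code, and the flash_attn_interface import, its return convention, and the fp32 PyTorch reference are assumptions for illustration.)

    import torch
    import torch.nn.functional as F
    import flash_attn_interface  # assumption: the extension built from hopper/ exposes this module

    # Shapes from the issue: batch 1, seqlen 8192, 36 heads, head_dim 256.
    batch, seqlen, nheads, headdim = 1, 8192, 36, 256
    q, k, v = [torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16,
                           requires_grad=True) for _ in range(3)]
    dout = torch.randn_like(q)

    # Reference: plain attention in fp32, backprop to get ref_dv.
    qf, kf, vf = [t.detach().float().requires_grad_() for t in (q, k, v)]
    ref_out = F.scaled_dot_product_attention(
        qf.transpose(1, 2), kf.transpose(1, 2), vf.transpose(1, 2)).transpose(1, 2)
    ref_out.backward(dout.float())
    ref_dv = vf.grad.half()

    # Kernel under test. The tuple handling is defensive because the exact return
    # signature of flash_attn_func here is an assumption, not verified.
    ret = flash_attn_interface.flash_attn_func(q, k, v, causal=False)
    out = ret[0] if isinstance(ret, (tuple, list)) else ret
    out.backward(dout)
    dv = v.grad

    torch.testing.assert_close(ref_dv, dv, atol=0.05, rtol=0.05)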
zidanehuang001 commented 4 days ago

Oh, it seems I didn't paste the full error log; adding it here:

### mode = 'bwd', batch_size = 1, headdim = 256, seqlen = 8192, causal = False ###
Traceback (most recent call last):
  File "/workspace/code/flash-attention/hopper/benchmark_attn.py", line 294, in <module>
    torch.testing.assert_close(ref_dv, dv, atol=0.05, rtol=0.05)
  File "/usr/local/lib/python3.10/dist-packages/torch/testing/_comparison.py", line 1530, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 5463253 / 75497472 (7.2%)
Greatest absolute difference: 4.13671875 at index (0, 4096, 1, 136) (up to 0.05 allowed)
Greatest relative difference: inf at index (0, 1857, 28, 186) (up to 0.05 allowed)
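
(A note on reading this: assert_close reports the relative difference as |actual - expected| / |expected|, so an inf relative difference means the reference element at that index is exactly zero while the kernel output is not. Combined with 7.2% of elements out of tolerance and an absolute error of 4.14 against atol = 0.05, this looks like a genuine kernel error rather than a tolerance that is merely too tight. A tiny illustration with made-up tensors:)

    import torch

    # Made-up tensors (not the benchmark data): the relative difference becomes inf
    # wherever the expected (reference) element is exactly 0 but the actual output is not.
    expected = torch.tensor([0.0, 1.0, 2.0])
    actual = torch.tensor([0.1, 1.0, 2.0])
    try:
        torch.testing.assert_close(actual, expected, atol=0.05, rtol=0.05)
    except AssertionError as err:
        print(err)  # "... Greatest relative difference: inf at index (0,) ..."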
tridao commented 4 days ago

Please try the tdd branch which supports bwd for hdim up to 256. We'll merge it soon.

zidanehuang001 commented 4 days ago

Thanks! I will try

zidanehuang001 commented 3 days ago

> Please try the tdd branch which supports bwd for hdim up to 256. We'll merge it soon.

Thank you for the solution; I can now run hdim256 bwd and the outputs match!

One more thing: hdim256 bwd at 9.215 ms is more than 4x the latency of fwd at 1.926 ms. Is there any room for further improvement?

### mode = 'fwd', batch_size = 1, headdim = 256, seqlen = 8192, causal = False ###
Fav2: 7.572ms, 326.7 TFLOPS
Fav3: 3.476ms, 711.7 TFLOPS
Fav3 varlen: 3.870ms, 639.3 TFLOPS

### mode = 'fwd', batch_size = 1, headdim = 256, seqlen = 8192, causal = True ###
Fav2: 4.137ms, 299.0 TFLOPS
Fav3: 1.926ms, 642.1 TFLOPS
Fav3 varlen: 2.014ms, 614.3 TFLOPS

### mode = 'bwd', batch_size = 1, headdim = 256, seqlen = 8192, causal = False ###
Fav2: 23.420ms, 264.1 TFLOPS
Fav3: 17.491ms, 353.6 TFLOPS
Fav3 varlen: 18.090ms, 341.9 TFLOPS

### mode = 'bwd', batch_size = 1, headdim = 256, seqlen = 8192, causal = True ###
Fav2: 11.967ms, 258.4 TFLOPS
Fav3: 9.215ms, 335.6 TFLOPS
Fav3 varlen: 9.771ms, 316.5 TFLOPS
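
(For scale: with the conventional attention FLOP count, the backward pass does 2.5x the FLOPs of the forward pass, so even at equal utilization it would take about 2.5x as long; the rest of the gap comes from the backward kernel reaching lower TFLOPS. A quick arithmetic check, assuming benchmark_attn.py uses the standard 4 * b * h * s^2 * d forward count; this assumption reproduces the table's TFLOPS numbers:)

    # Assumed FLOP model (standard for attention benchmarks; not verified against benchmark_attn.py):
    # fwd FLOPs = 4 * batch * nheads * seqlen^2 * headdim (non-causal; roughly halved for causal),
    # bwd FLOPs = 2.5 * fwd FLOPs (5 matmuls in bwd vs 2 in fwd).
    batch, nheads, seqlen, headdim = 1, 36, 8192, 256
    fwd_flops = 4 * batch * nheads * seqlen**2 * headdim   # ~2.47e12
    bwd_flops = 2.5 * fwd_flops                             # ~6.18e12

    # Reported Fav3 non-causal timings from the table above.
    fwd_ms, bwd_ms = 3.476, 17.491
    print(f"fwd: {fwd_flops / (fwd_ms * 1e-3) / 1e12:.1f} TFLOPS")   # ~711.7, matches the table
    print(f"bwd: {bwd_flops / (bwd_ms * 1e-3) / 1e12:.1f} TFLOPS")   # ~353.6, matches the table
    print(f"bwd/fwd latency: {bwd_ms / fwd_ms:.2f}x")                # ~5.0x: 2.5x from FLOPs, ~2x from utilization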
tridao commented 3 days ago

You're welcome to work on it! We've seen the best perf with CUDA 12.3.