Open · pengzhangzhi opened this issue 1 day ago
I recently benchmarked FlashAttention against PyTorch's scaled_dot_product_attention using a custom script and observed a significant discrepancy between the two outputs. Below are the details of the issue:
Benchmark Summary:
Steps to Reproduce:
Here is the benchmark code I used for testing:
The discrepancy is consistent across multiple runs, even with deterministic settings (e.g., fixed seeds).
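
For reference, here is a minimal sketch of the kind of comparison described above (not the original script). It assumes flash_attn_func from the flash_attn package, half-precision tensors on a CUDA device, and placeholder sizes B, L, H, D.

```python
# Minimal sketch of a FlashAttention vs. scaled_dot_product_attention check
# (illustrative only; assumes a CUDA GPU and the flash_attn package).
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

torch.manual_seed(0)
B, L, H, D = 2, 1024, 8, 64  # placeholder batch, seqlen, heads, head dim

# flash_attn_func expects q, k, v of shape (batch, seqlen, nheads, headdim) in fp16/bf16.
q = torch.randn(B, L, H, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, L, H, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, L, H, D, device="cuda", dtype=torch.float16)

out_flash = flash_attn_func(q, k, v, causal=False)  # (B, L, H, D)

# scaled_dot_product_attention expects (batch, nheads, seqlen, headdim),
# so transpose in and back out before comparing.
out_sdpa = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=False
).transpose(1, 2)

print("max abs diff:", (out_flash - out_sdpa).abs().max().item())
```

With matching layouts and the default softmax scale, the two should agree to within normal fp16 rounding, so a large gap usually points to a layout mismatch rather than a kernel difference.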
Questions and Notes:
Thank you for your time and for developing such an efficient attention mechanism! Please let me know if you need further details or additional benchmarks.
Best regards,

Reply:

This reshape is not what you want:

qkv = torch.stack((query_layer, key_layer, value_layer), dim=1).reshape(B, L, 3, H, D)

Please see the docstring, or the tests here (https://github.com/Dao-AILab/flash-attention/blob/7153673c1a3c7753c38e4c10ef2c98a02be5f778/tests/test_flash_attn.py#L586), for the input layout the function expects.
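
For illustration, here is a minimal sketch of the packing difference, assuming flash_attn_qkvpacked_func from the flash_attn package and tensors q, k, v of shape (B, L, H, D); the sizes below are placeholders, not values from the report.

```python
# Sketch: building the packed QKV tensor for flash_attn_qkvpacked_func,
# which expects qkv of shape (batch, seqlen, 3, nheads, headdim).
import torch
from flash_attn import flash_attn_qkvpacked_func

B, L, H, D = 2, 1024, 8, 64  # placeholder sizes
q = torch.randn(B, L, H, D, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Correct: stack along a new dim=2, giving (B, L, 3, H, D) directly.
qkv = torch.stack((q, k, v), dim=2)

# Incorrect: stacking on dim=1 gives (B, 3, L, H, D); reshaping that to
# (B, L, 3, H, D) does not transpose the data, it only reinterprets memory,
# so q, k, and v rows end up interleaved across sequence positions.
qkv_wrong = torch.stack((q, k, v), dim=1).reshape(B, L, 3, H, D)

out = flash_attn_qkvpacked_func(qkv, causal=False)  # (B, L, H, D)
```

With the correct packing (or by passing separate q, k, v tensors to flash_attn_func), the output should match scaled_dot_product_attention up to fp16/bf16 rounding.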