ROCm / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

IFU to v2.0.4 #14

Closed · jayz0123 closed this 11 months ago

jayz0123 commented 1 year ago

Current unit test results (PyTorch 2.0.0; ROCm 5.6): 3968 passed, 63 skipped

Current performance on MI250 (docker pull rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1):

| dtype | fwd TFLOPS | bwd TFLOPS | total TFLOPS |
|-------|------------|------------|--------------|
| fp16  | 52.16      | 39.93      | 42.49        |
| bf16  | 52.36      | 30.25      | 34.21        |
jayz0123 commented 11 months ago

A new environment variable, FLASH_ATTENTION_INTERNAL_ENABLE_TIME_KERNEL, toggles whether kernel running times are reported.
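A minimal sketch of how the flag might be used, assuming it is read when the extension loads and that `"1"` is the enabling value (the accepted values and the output format are not specified in this thread):

```python
# Hedged sketch: the variable name comes from this PR; treating "1" as the
# "on" value and reading it at import time are assumptions for illustration.
import os

os.environ["FLASH_ATTENTION_INTERNAL_ENABLE_TIME_KERNEL"] = "1"

# Import after setting the variable so the extension can pick it up at load time.
from flash_attn import flash_attn_func  # noqa: E402
```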

jayz0123 commented 11 months ago

[BUG] In older versions of FA, the z tensor and the softmax_lse matrix were created at the maximum sequence length with no padding for grouped GEMM, but the per-batch strides of these tensors differ. This behaviour causes CK to return wrong results. Fixing it.
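A hedged illustration of the stride mismatch described above. The tensor names (z, softmax_lse) come from the PR; the shapes and the assumption that the grouped-GEMM kernel expects one fixed stride per batch are illustrative, not the library's exact layout:

```python
# Sketch of why non-uniform per-batch strides break a grouped-GEMM kernel
# that indexes batches with a single fixed stride. Shapes are illustrative.
import torch

nheads = 8
seqlens = [96, 128, 64]          # per-batch sequence lengths, no padding
max_seqlen = max(seqlens)

# Ragged allocation: each batch slice has a different shape, so the
# leading stride differs per batch and a fixed-stride kernel reads wrong data.
ragged_lse = [torch.empty(nheads, s) for s in seqlens]
print([t.stride(0) for t in ragged_lse])   # [96, 128, 64] -- not uniform

# Padded allocation to max_seqlen: every batch slice shares the same
# strides, matching a fixed per-batch stride contract.
padded_lse = torch.empty(len(seqlens), nheads, max_seqlen)
print(padded_lse[0].stride())              # identical for every batch index
```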

fsx950223 commented 11 months ago

Please remove *_hip.hpp