tspeterkim / flash-attention-minimal

Flash Attention in ~100 lines of CUDA (forward pass only)
Apache License 2.0

Add matmul optimization #4

Open Byeong-Chan opened 5 months ago

Byeong-Chan commented 5 months ago

Description

This PR implements a matrix multiplication optimization for the flash attention forward pass (~300 lines).
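For context, half-precision matmul on Ampere-class GPUs is usually routed through tensor cores. The sketch below is not taken from this PR; it only illustrates how a single warp could compute one 16x16 tile of the attention scores S = QK^T with the wmma API while accumulating in float. The kernel name, the single-warp launch, and the assumption that the head dimension d is a multiple of 16 are illustrative only.

```cuda
// Illustrative sketch (not the PR's kernel): one warp computes a 16x16 tile of
// S = Q K^T with tensor cores, half inputs, float accumulation.
// Assumes a launch with a single warp (32 threads) and d a multiple of 16.
#include <cuda_fp16.h>
#include <mma.h>

using namespace nvcuda;

__global__ void qk_tile_matmul(const half* Q, const half* K, float* S, int d) {
    // Fragments for a 16x16x16 tensor-core multiply-accumulate.
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> q_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> k_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> s_frag;

    wmma::fill_fragment(s_frag, 0.0f);

    // Walk along the head dimension d in steps of 16, accumulating Q K^T.
    for (int k = 0; k < d; k += 16) {
        wmma::load_matrix_sync(q_frag, Q + k, d);       // 16 rows of Q, stride d
        wmma::load_matrix_sync(k_frag, K + k, d);       // 16 rows of K, read as K^T
        wmma::mma_sync(s_frag, q_frag, k_frag, s_frag); // S += Q_tile * K_tile^T
    }

    // Write the 16x16 score tile back in float.
    wmma::store_matrix_sync(S, s_frag, 16, wmma::mem_row_major);
}
```

Keeping the accumulator in float while Q and K stay in half is a common way for a half-precision path to remain close enough to the reference output to pass an allclose-style sanity check like the one below.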

I got these results on my RTX 3060 (the kernel targets sm_80 or newer):

With the original minimal kernel (float):

=== profiling manual attention ===
...
Self CPU time total: 834.368ms
Self CUDA time total: 835.075ms

=== profiling minimal flash attention === 
...
Self CPU time total: 668.000us
Self CUDA time total: 687.000us

attn values sanity check: True

With the matmul-optimized kernel (half):

=== profiling manual attention ===
...
Self CPU time total: 849.544ms
Self CUDA time total: 849.698ms

=== profiling minimal flash attention ===
...
Self CPU time total: 89.000us
Self CUDA time total: 93.000us

attn values sanity check: True
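Comparing the two runs, the minimal flash attention kernel's CUDA time drops from 687 µs to 93 µs, roughly a 7x improvement, while the manual attention baseline stays around 835-850 ms in both runs. Note the runs also differ in precision (float vs. half), so this is not a matmul-only comparison.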

Reference