nshepperd / flash_attn_jax

JAX bindings for Flash Attention v2
BSD 3-Clause "New" or "Revised" License

performance benchmark #5

Open cliangyu opened 1 month ago

cliangyu commented 1 month ago

Hi, have you benchmarked fa2 with jax? How much speedup can you get?

nshepperd commented 1 month ago

I ran a small benchmark comparing fa2, a naive MHA, and the Pallas MHA included with jax 0.4.28. Results on my desktop with a 3090, using float16:

B     T    H    C    TFlop/s (flash)    TFlop/s (naive)    TFlop/s (pallas)
---  ----  ---  ---  -----------------  -----------------  ------------------
 32  1024    4   32            53.8639            8.24277             51.5742
 32  1024    4   64            57.5379           14.9229              53.5826
 32  1024    4  128            58.8711           26.7008              54.7979
 32  1024    8   32            61.1098            8.6376              51.487
 32  1024    8   64            64.0922           15.6025              52.9087
 32  1024    8  128            65.6265           27.115               51.4726
 32  1024   16   32            63.9256            8.74247             52.2687
 32  1024   16   64            63.1579           15.7274              55.8753
 32  1024   16  128            65.6789           27.7762              57.1428
 32  1024   32   32            63.1234            8.63395             52.5392
 32  1024   32   64            66.5244           15.4328              56.243
 32  1024   32  128            65.5683           28.3536              59.0203
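
For anyone who wants to reproduce numbers like these, here is a minimal sketch of how a TFlop/s figure could be measured. It is not the script used for the table above. It assumes that B, T, H, C mean batch size, sequence length, number of heads and per-head channels, that inputs are laid out as [B, T, H, C], and that only the two forward-pass matmuls are counted (4 * B * H * T^2 * C flops); the thread does not state which flop count was used. The names `attention_flops`, `naive_mha` and `measure_tflops` are hypothetical helpers, not part of flash_attn_jax.

```python
import time

import jax
import jax.numpy as jnp


def attention_flops(b, t, h, c):
    """Forward-pass matmul flops (assumed convention, not confirmed by the thread).

    QK^T and PV each cost 2 * b * h * t * t * c flops.
    """
    return 4 * b * h * t * t * c


@jax.jit
def naive_mha(q, k, v):
    """Plain softmax attention on [B, T, H, C] inputs, no masking."""
    scale = q.shape[-1] ** -0.5
    scores = jnp.einsum("bqhc,bkhc->bhqk", q, k) * scale
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhc->bqhc", weights, v)


def measure_tflops(fn, b, t, h, c, dtype=jnp.float16, iters=10):
    """Time fn(q, k, v) on random inputs and convert to TFlop/s."""
    keys = jax.random.split(jax.random.PRNGKey(0), 3)
    q, k, v = (jax.random.normal(key, (b, t, h, c), dtype=dtype) for key in keys)
    # Warm-up call so compilation stays outside the timed loop.
    fn(q, k, v).block_until_ready()
    start = time.perf_counter()
    for _ in range(iters):
        # JAX dispatches asynchronously, so wait for each result.
        fn(q, k, v).block_until_ready()
    seconds_per_call = (time.perf_counter() - start) / iters
    return attention_flops(b, t, h, c) / seconds_per_call / 1e12


if __name__ == "__main__":
    # Small shape so the sketch also runs on CPU; use e.g. (32, 1024, 8, 64) on a GPU.
    print(measure_tflops(naive_mha, 2, 128, 4, 32))
```

Any other attention function that takes q, k, v in the same layout could be passed in place of `naive_mha`, which keeps the timing and flop accounting identical across the implementations being compared.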