Open cliangyu opened 1 month ago
I ran a small benchmark comparing fa2, a naive MHA, and the Pallas MHA included with jax 0.4.28. Results on my desktop with a 3090, in float16:
| B | T | H | C | TFlop/s (flash) | TFlop/s (naive) | TFlop/s (pallas) |
| --- | ---- | --- | --- | ----------------- | ----------------- | ------------------ |
| 32 | 1024 | 4 | 32 | 53.8639 | 8.24277 | 51.5742 |
| 32 | 1024 | 4 | 64 | 57.5379 | 14.9229 | 53.5826 |
| 32 | 1024 | 4 | 128 | 58.8711 | 26.7008 | 54.7979 |
| 32 | 1024 | 8 | 32 | 61.1098 | 8.6376 | 51.487 |
| 32 | 1024 | 8 | 64 | 64.0922 | 15.6025 | 52.9087 |
| 32 | 1024 | 8 | 128 | 65.6265 | 27.115 | 51.4726 |
| 32 | 1024 | 16 | 32 | 63.9256 | 8.74247 | 52.2687 |
| 32 | 1024 | 16 | 64 | 63.1579 | 15.7274 | 55.8753 |
| 32 | 1024 | 16 | 128 | 65.6789 | 27.7762 | 57.1428 |
| 32 | 1024 | 32 | 32 | 63.1234 | 8.63395 | 52.5392 |
| 32 | 1024 | 32 | 64 | 66.5244 | 15.4328 | 56.243 |
| 32 | 1024 | 32 | 128 | 65.5683 | 28.3536 | 59.0203 |
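For anyone who wants to reproduce numbers of this kind, here is a rough sketch of how a TFlop/s figure for the naive baseline could be measured in JAX. It is not the script used for the table above. The helper names (`naive_mha`, `attention_flops`, `bench_tflops`) are hypothetical, and the FLOP count assumes a forward pass only with no causal mask, which the post does not specify.

```python
import time

import jax
import jax.numpy as jnp


def naive_mha(q, k, v):
    """Unfused attention; q, k, v have shape (B, T, H, C)."""
    scale = q.shape[-1] ** -0.5
    logits = jnp.einsum("bthc,bshc->bhts", q, k) * scale
    # Softmax in float32 for stability, then cast back to the input dtype.
    weights = jax.nn.softmax(logits.astype(jnp.float32), axis=-1).astype(q.dtype)
    return jnp.einsum("bhts,bshc->bthc", weights, v)


def attention_flops(B, T, H, C):
    # Assumption: forward only, no causal mask.
    # Two matmuls (Q @ K^T and P @ V), each costing 2 * B * H * T * T * C FLOPs.
    return 4 * B * H * T * T * C


def bench_tflops(fn, B, T, H, C, dtype=jnp.float16, iters=10):
    keys = jax.random.split(jax.random.PRNGKey(0), 3)
    q, k, v = (jax.random.normal(key, (B, T, H, C), dtype=dtype) for key in keys)
    fn = jax.jit(fn)
    fn(q, k, v).block_until_ready()  # compile and warm up outside the timed loop
    start = time.perf_counter()
    for _ in range(iters):
        out = fn(q, k, v)
    out.block_until_ready()  # wait for the queued work before stopping the clock
    seconds = (time.perf_counter() - start) / iters
    return attention_flops(B, T, H, C) / seconds / 1e12
```

Calling `bench_tflops(naive_mha, 32, 1024, 8, 64)` would correspond to one row of the table. The result depends on the GPU, the dtype, and the FLOP convention, so it will only be comparable to the numbers above if those match.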
Hi, have you benchmarked fa2 with jax? How much speedup can you get?