takumi7110 / paper

0 stars 0 forks source link

[2023]FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning #14

Open takumi7110 opened 1 year ago

takumi7110 commented 1 year ago

【背景】 • トランスフォーマーのコンテキスト長のスケーリングは課題であり、アテンション層が長いシーケンスにスケーリングする際のボトルネックとなっている。 • FlashAttentionは、非対称のGPUメモリ階層を利用してメモリの節約とランタイムの高速化を実現しているが、最適化された行列乗算(GEMM)操作に比べて効率が低い。 【目的】 • FlashAttentionの効率を向上させるために、より良い並列化とワークパーティショニングを提案する。 【手法】 • アルゴリズムを調整して非行列乗算FLOPsの数を減らす。 • アテンション計算を並列化し、スレッドブロック間でのオキュパンシーを増やす。 • スレッドブロック内でのワープ間でのワークを分散させ、共有メモリの読み書きを減らす。 【実験方法】 • FlashAttention-2を提案し、FlashAttentionと比較してスピードアップを検証する。 • 異なる設定でのベンチマークを実施し、前方パスと後方パスの理論的な最大スループットの割合を計測する。 • GPTスタイルのモデルのトレーニング速度を測定する。 【実験結果】 • FlashAttention-2はFlashAttentionと比較して約2倍のスピードアップを達成し、前方パスでは理論的な最大スループットの73%、後方パスでは63%に達する。 • GPTスタイルのモデルのトレーニング速度は、A100 GPUあたり225 TFLOPs/sに達する。 【考察】 • FlashAttention-2は、最適化されたGEMM操作とほぼ同等の効率に近づく。 • 提案手法により、スレッドブロックとワープ間のワークパーティショニングが改善され、効率が向上したことが確認された。