bigcode-project / Megatron-LM

Ongoing research training transformer models at scale

Add flash-attn #41

Closed RaymondLi0 closed 1 year ago

RaymondLi0 commented 1 year ago

Adds flash-attention, based on https://github.com/NVIDIA/Megatron-LM/pull/267, with support for MQA.
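For a concrete picture of MQA, here is a minimal pure-PyTorch sketch of multi-query attention (names, shapes, and the causal mask are illustrative, not this PR's implementation). The point is that a single K/V head is shared by all query heads; flash-attn computes the same result without ever materializing the `[seq, seq]` score matrix.

```python
import torch

def multi_query_attention(q, k, v):
    """Illustrative multi-query attention (MQA) in eager PyTorch.

    q:    [batch, n_heads, seq, head_dim]  -- one query projection per head
    k, v: [batch, 1,       seq, head_dim]  -- a single shared key/value head
    The shared K/V head is broadcast across all query heads, which is what
    makes MQA cheaper in KV memory and bandwidth than multi-head attention.
    """
    seq = q.shape[-2]
    scale = q.shape[-1] ** -0.5
    # [batch, n_heads, seq, seq]: the matrix flash-attn avoids materializing
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    causal = torch.triu(
        torch.ones(seq, seq, dtype=torch.bool, device=q.device), diagonal=1
    )
    scores = scores.masked_fill(causal, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v)  # [batch, n_heads, seq, head_dim]
```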

RaymondLi0 commented 1 year ago

Some tests with a 1B MQA model (SantaCoder's config: num_layers 24, num_heads 16, hidden_size 2048), bf16, on a single A100 GPU.

With flash-attn, this model can be trained with sequences of length up to 8192, and with full-recomputation up to 32768. With normal-attn, we only reach 2048, or 8192 with selective or full recomputation.

Flash-attn is faster, especially for longer sequences. Time per iteration (ms):

- seq-length 2048: flash-attn 19794.1 vs. normal-attn 22679.3
- seq-length 4096: flash-attn 44040.8 vs. normal-attn 71740 (selective recomputation)
- seq-length 8192: flash-attn 113715.5 vs. normal-attn 256122 (selective recomputation)
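The sequence-length limits above are dominated by the O(seq_len²) attention-score activations that standard attention materializes and flash-attn does not. A rough, illustrative back-of-envelope for this config (it counts only one bf16 copy of the score matrix per layer, so it is not Megatron's exact memory accounting):

```python
# Size of the [seq, seq] attention-score activations that standard attention
# materializes, for the config above: 24 layers, 16 heads, micro-batch 2,
# bf16 (2 bytes/element). Illustrative only.
def score_matrix_gib(seq_len, n_layers=24, n_heads=16, micro_batch=2, bytes_per_el=2):
    per_layer = micro_batch * n_heads * seq_len * seq_len * bytes_per_el
    return n_layers * per_layer / 1024**3

for s in (2048, 4096, 8192):
    print(f"seq_len {s}: ~{score_matrix_gib(s):.0f} GiB of attention scores")
# seq_len 2048: ~6 GiB, seq_len 4096: ~24 GiB, seq_len 8192: ~96 GiB
```

Even this single term reaches tens of GiB at seq_len 4096, which lines up with normal-attn going OOM there without recomputation, while flash-attn's activation footprint grows only linearly with sequence length.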


use_flash_attn | seq_len | mbs | gbs | Activation recomputation | mem_reserved (GB) | iteration_time (ms) | TFLOPs
-- | -- | -- | -- | -- | -- | -- | --
TRUE | 512 | 2 | 192 | None | 22.45 | 6323.2 | 109.18
TRUE | 1024 | 2 | 192 | None | 24.73 | 9888.1 | 145.64
TRUE | 2048 | 2 | 192 | None | 29.27 | 19794.1 | 157.51
TRUE | 4096 | 2 | 192 | None | 38.36 | 44040.8 | 163.15
TRUE | 8192 | 2 | 192 | None | 56.63 | 113715.5 | 159.75
TRUE | 16384 | 2 | 192 | None | OOM | OOM | OOM
TRUE | 512 | 2 | 192 | Full | 20.8 | 8606.4 | 104.65
TRUE | 1024 | 2 | 192 | Full | 21.4 | 12837.4 | 146.48
TRUE | 2048 | 2 | 192 | Full | 22.58 | 26172.6 | 155.8
TRUE | 4096 | 2 | 192 | Full | 25 | 58775.6 | 160.3
TRUE | 8192 | 2 | 192 | Full | 29.85 | 151869.1 | 157.44
TRUE | 16384 | 2 | 192 | Full | 36.95 | 442532.3 | 153.86
TRUE | 32768 | 2 | 192 | Full | 50.72 | 1440645.7 | 150.79
FALSE | 512 | 2 | 192 | None | 23.18 | 6138.9 | 110.4
FALSE | 1024 | 2 | 192 | None | 28.24 | 10404.8 | 137.88
FALSE | 2048 | 2 | 192 | None | 44.46 | 22679.3 | 137.47
FALSE | 4096 | 2 | 192 | None | OOM | OOM | OOM
FALSE | 512 | 2 | 192 | Selective | 22.18 | 6840.5 | 100.92
FALSE | 1024 | 2 | 192 | Selective | 24.18 | 11216.9 | 128.39
FALSE | 2048 | 2 | 192 | Selective | 28.67 | 25446.1 | 122.52
FALSE | 4096 | 2 | 192 | Selective | 38.95 | 71740 | 100.16
FALSE | 8192 | 2 | 192 | Selective | 65.74 | 256122 | 70.95
FALSE | 16384 | 2 | 192 | Selective | OOM | OOM | OOM
FALSE | 512 | 2 | 192 | Full | 20.8 | 8109.4 | 111.06
FALSE | 1024 | 2 | 192 | Full | 21.44 | 13569.5 | 138.58
FALSE | 2048 | 2 | 192 | Full | 23.3 | 29930.3 | 136.24
FALSE | 4096 | 2 | 192 | Full | 28.06 | 81514.7 | 115.58
FALSE | 8192 | 2 | 192 | Full | 46.4 | 276012.9 | 86.63
FALSE | 16384 | 2 | 192 | Full | OOM | OOM | OOM
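As a rough cross-check of the TFLOPs column, a sketch using the common FLOPs ≈ 6 · params · tokens approximation for forward + backward (the ~1.1B parameter count is an assumption; the approximation ignores the attention-matrix FLOPs and any recomputation, so it lands below the reported numbers and the gap widens with sequence length):

```python
# Back-of-envelope throughput check; the parameter count (~1.1B) is assumed.
n_params = 1.1e9
global_batch = 192

for seq_len, iter_time_ms, reported in [
    (2048, 19794.1, 157.51),
    (4096, 44040.8, 163.15),
    (8192, 113715.5, 159.75),
]:
    tokens = global_batch * seq_len
    tflops = 6 * n_params * tokens / (iter_time_ms / 1000) / 1e12
    print(f"seq_len {seq_len}: ~{tflops:.0f} TFLOP/s (reported {reported})")
# ~131, ~118 and ~91 TFLOP/s respectively: the shortfall vs. the reported
# numbers is roughly the attention FLOPs the approximation leaves out.
```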

RaymondLi0 commented 1 year ago

Additional test currently running: training runs of 5k steps should give the same loss for the normal-attn model vs. flash-attn vs. flash-attn with TP and SP.

(Screenshot, 2023-03-24: training-loss comparison for the three runs)
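A loss-equivalence check like the one plotted above can also be scripted; a minimal sketch, assuming each run's per-iteration loss has been dumped to a text file (file names are hypothetical):

```python
import numpy as np

# Hypothetical per-iteration loss dumps, one value per line, for the three runs.
runs = {
    "normal-attn": "loss_normal_attn.txt",
    "flash-attn": "loss_flash_attn.txt",
    "flash-attn + TP + SP": "loss_flash_attn_tp_sp.txt",
}

losses = {name: np.loadtxt(path) for name, path in runs.items()}
reference = losses["normal-attn"]

for name, curve in losses.items():
    # bf16 training and kernel differences mean the curves will not match
    # bit-for-bit; we only check that they track each other closely.
    max_abs_diff = np.max(np.abs(curve - reference))
    print(f"{name}: max |loss - reference| over 5k steps = {max_abs_diff:.4f}")
```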