An extra benchmark on an H100 (SXM5, 94 GB):
| Model | Batch size | Seq len | Max memory (GB) | Throughput on A100 (tokens/s) | Throughput on H100 (tokens/s) |
|---|---|---|---|---|---|
| GLA | 8 | 512 | 14.77 | 14959.27 | 31704.30 |
| GLA | 8 | 1024 | 22.88 | 17467.75 | 37256.27 |
| GLA | 8 | 2048 | OOM | | 41293.03 |
| GSA | 8 | 512 | 16.07 | 14960.99 | 31672.98 |
| GSA | 8 | 1024 | 24.35 | 17674.00 | 37659.39 |
| GSA | 8 | 2048 | OOM | | 41753.57 |
| HGRN | 8 | 512 | 16.9 | 16382.00 | 35731.41 |
| HGRN | 8 | 1024 | 26.15 | 19500.58 | 42272.09 |
| HGRN | 8 | 2048 | OOM | | 49234.70 |
| retnet | 8 | 512 | 15.13 | 13369.01 | 27383.20 |
| retnet | 8 | 1024 | 22.66 | 15437.14 | 31250.06 |
| retnet | 8 | 2048 | 37.75 | 16445.58 | 34090.15 |
| transformer | 8 | 512 | 13.98 | 17851.42 | 40468.24 |
| transformer | 8 | 1024 | 20.30 | 20994.52 | 47119.66 |
| transformer | 8 | 2048 | 32.96 | 21807.02 | 49627.74 |
| Mamba | 8 | 512 | 15.40 | 10385.55 | 19980.05 |
| Mamba | 8 | 1024 | 22.72 | 11230.94 | 21674.81 |
| Mamba | 8 | 2048 | 37.36 | 12151.64 | 23410.73 |
| Samba | 8 | 512 | 13.77 | 16475.11 | 33476.98 |
| Samba | 8 | 1024 | 19.56 | 18470.86 | 38945.52 |
| Samba | 8 | 2048 | 31.18 | 19850.30 | 42370.31 |
@rakkit Hi, thank you for reporting the bugs. The new commits have fixed the ones you mentioned. However, HGRN2 and LinearAttn work fine for me; please check your triton/causal-conv1d versions.
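For reference, a quick, generic way to print the installed versions (nothing FLA-specific; package names are the ones published on PyPI):

```python
# Print the versions of the packages relevant to this issue.
from importlib.metadata import PackageNotFoundError, version

for pkg in ("triton", "causal-conv1d", "mamba-ssm", "torch"):
    try:
        print(f"{pkg}: {version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")
```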
I've fixed the HGRN2 bug in https://github.com/sustcsonglin/flash-linear-attention/commit/a7bb4b7f71bec43d72a7486436d4b837b44e4333: AMP silently converts the keys to float32 due to the use of sigmoid. Now it should be good.
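To illustrate the class of bug (a minimal sketch with made-up shapes, not the actual HGRN2 code): once a gate is computed in float32 for numerical stability, everything derived from it silently stays float32 under AMP unless it is explicitly cast back, which is what a bf16 Triton kernel does not expect.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(2, 8, 64, device=device, dtype=torch.bfloat16)

with torch.autocast(device_type=device, dtype=torch.bfloat16):
    g = torch.sigmoid(x.float())     # gate computed in float32 for stability
    k = 1 - g                        # the derived "keys" silently stay float32
    assert k.dtype == torch.float32  # not what a bf16 kernel expects
    k = k.to(x.dtype)                # the fix: cast back before the kernel call
```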
Hi, @yzhangcs @sustcsonglin. Thanks for fixing the bug so quickly. I can confirm all models are working now.
There is a minor issue in the benchmark code: `seq_len` is not passed to the config, so `config.max_position_embeddings` stays at its default. Models such as the transformer then fail in long-sequence benchmarks (seq_len > max_position_embeddings); a sketch of what I mean is below.
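A minimal sketch of the kind of change I mean, assuming a Hugging Face-style config (the names here are illustrative, not the actual benchmark code):

```python
from transformers import AutoConfig, AutoModelForCausalLM

def build_model(model_type: str, seq_len: int):
    # Illustrative only: make sure the config knows about the benchmark's
    # sequence length so absolute-position models do not index out of range.
    config = AutoConfig.for_model(model_type)
    if hasattr(config, "max_position_embeddings"):
        config.max_position_embeddings = max(config.max_position_embeddings, seq_len)
    return AutoModelForCausalLM.from_config(config)
```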
The following are the full benchmark results for NVIDIA A100-SXM4-40GB and NVIDIA H100 PCIe-80GB.
(I plot BS=1 only here.)
| Model | BS | Seq len | A100 (tokens/s) | H100 (tokens/s) |
|---|---|---|---|---|
| samba | 1 | 32768 | | 28904.69 |
| retnet | 1 | 32768 | | 22108.37 |
| delta-net | 1 | 32768 | | 20419.72 |
| mamba | 1 | 32768 | | 18535.8 |
| transformer | 1 | 32768 | | 12216.43 |
| samba | 2 | 16384 | | 27565.2 |
| retnet | 2 | 16384 | | 23154.44 |
| delta-net | 2 | 16384 | | 21761.01 |
| transformer | 2 | 16384 | | 18647.61 |
| mamba | 2 | 16384 | | 16121.99 |
| hgrn | 1 | 16384 | | 30473.15 |
| samba | 1 | 16384 | 20055.28 | 27931.16 |
| gsa | 1 | 16384 | | 25442.21 |
| hgrn2 | 1 | 16384 | | 24075.14 |
| retnet | 1 | 16384 | | 21964.05 |
| linear-attn | 1 | 16384 | | 21414 |
| delta-net | 1 | 16384 | | 20194.1 |
| rwkv6 | 1 | 16384 | | 18773.85 |
| transformer | 1 | 16384 | | 18242.55 |
| mamba | 1 | 16384 | 13811.72 | 18017 |
| samba | 4 | 8192 | | 27906.56 |
| transformer | 4 | 8192 | | 25004.98 |
| retnet | 4 | 8192 | | 24045.87 |
| delta-net | 4 | 8192 | | 22723.17 |
| mamba | 4 | 8192 | | 16203.86 |
| hgrn | 2 | 8192 | | 31345.5 |
| samba | 2 | 8192 | 19210.89 | 27332.59 |
| gsa | 2 | 8192 | | 26957.65 |
| hgrn2 | 2 | 8192 | | 24674.07 |
| transformer | 2 | 8192 | | 24419.43 |
| retnet | 2 | 8192 | 16247.82 | 22721.79 |
| linear-attn | 2 | 8192 | | 22010.4 |
| delta-net | 2 | 8192 | | 21369.47 |
| rwkv6 | 2 | 8192 | | 18945.29 |
| mamba | 2 | 8192 | 11941.12 | 15749.19 |
| hgrn | 1 | 8192 | 18675.9 | 27612.78 |
| samba | 1 | 8192 | 18699.51 | 26528.8 |
| gsa | 1 | 8192 | 16001.26 | 23310.63 |
| transformer | 1 | 8192 | | 22464.27 |
| hgrn2 | 1 | 8192 | 15501.05 | 22117.39 |
| retnet | 1 | 8192 | 14617.58 | 20519.9 |
| linear-attn | 1 | 8192 | 14021.22 | 19380.21 |
| delta-net | 1 | 8192 | 13352.6 | 18804.9 |
| rwkv6 | 1 | 8192 | | 17398.72 |
| mamba | 1 | 8192 | 13288.27 | 17014.34 |
| transformer | 8 | 4096 | | 30320.49 |
| samba | 8 | 4096 | | 28164.83 |
| retnet | 8 | 4096 | | 24162.65 |
| delta-net | 8 | 4096 | | 22879.71 |
| mamba | 8 | 4096 | | 16240.19 |
| hgrn | 4 | 4096 | | 31571.14 |
| transformer | 4 | 4096 | | 29370.23 |
| gsa | 4 | 4096 | | 27980.19 |
| samba | 4 | 4096 | 19486.16 | 27797.24 |
| hgrn2 | 4 | 4096 | | 24948.31 |
| retnet | 4 | 4096 | 16406.06 | 23430.21 |
| linear-attn | 4 | 4096 | | 22187.43 |
| delta-net | 4 | 4096 | | 22106.26 |
| rwkv6 | 4 | 4096 | | 19107.01 |
| mamba | 4 | 4096 | 12146.53 | 15883.88 |
| hgrn | 2 | 4096 | 19252.6 | 28482.19 |
| transformer | 2 | 4096 | | 26838.27 |
| samba | 2 | 4096 | 18017.55 | 25800.91 |
| gsa | 2 | 4096 | 17027 | 24598.84 |
| hgrn2 | 2 | 4096 | 15847.4 | 22766.22 |
| retnet | 2 | 4096 | 15268.96 | 21436.01 |
| linear-attn | 2 | 4096 | 14599.34 | 20324.97 |
| delta-net | 2 | 4096 | 14053.8 | 19777.87 |
| rwkv6 | 2 | 4096 | | 17706.11 |
| mamba | 2 | 4096 | 11631.5 | 15203.35 |
| hgrn | 1 | 4096 | 15768.39 | 23586.86 |
| samba | 1 | 4096 | 16493.58 | 23347.85 |
| transformer | 1 | 4096 | | 23195.83 |
| hgrn2 | 1 | 4096 | 13383.31 | 19303.13 |
| retnet | 1 | 4096 | 12748.79 | 18549.01 |
| gsa | 1 | 4096 | 13751.03 | 18382.57 |
| linear-attn | 1 | 4096 | 12239.79 | 17267.57 |
| delta-net | 1 | 4096 | 11686 | 16246.66 |
| rwkv6 | 1 | 4096 | 10452.14 | 15948.46 |
| mamba | 1 | 4096 | 11841.48 | 15932.17 |
| transformer | 16 | 2048 | | 33974.39 |
| samba | 16 | 2048 | | 28634.4 |
| retnet | 16 | 2048 | | 24150 |
| delta-net | 16 | 2048 | | 22971.42 |
| mamba | 16 | 2048 | | 16222.58 |
| transformer | 8 | 2048 | 21955.1 | 32656.87 |
| hgrn | 8 | 2048 | | 31816.45 |
| gsa | 8 | 2048 | | 28157.97 |
| samba | 8 | 2048 | 19838.32 | 27929.99 |
| hgrn2 | 8 | 2048 | | 25083.31 |
| retnet | 8 | 2048 | 16564.31 | 23576.51 |
| linear-attn | 8 | 2048 | | 22438.17 |
| delta-net | 8 | 2048 | | 22272.51 |
| rwkv6 | 8 | 2048 | | 19198.61 |
| mamba | 8 | 2048 | 12176.65 | 15861.45 |
| transformer | 4 | 2048 | 20264.96 | 29556.21 |
| hgrn | 4 | 2048 | 19429.87 | 28679.4 |
| samba | 4 | 2048 | 18297.54 | 26065.31 |
| gsa | 4 | 2048 | 17566.37 | 25574.87 |
| hgrn2 | 4 | 2048 | 15941.95 | 22907.09 |
| linear-attn | 4 | 2048 | 14663.66 | 20819.64 |
| delta-net | 4 | 2048 | 14558.07 | 20432.8 |
| retnet | 4 | 2048 | 15405.69 | 20195.56 |
| rwkv6 | 4 | 2048 | | 17712.64 |
| mamba | 4 | 2048 | 11674.9 | 15177.81 |
| transformer | 2 | 2048 | 16961.03 | 25506.14 |
| hgrn | 2 | 2048 | 16207.16 | 24051.77 |
| samba | 2 | 2048 | 16119.07 | 22902.87 |
| gsa | 2 | 2048 | 14500.91 | 21552.52 |
| hgrn2 | 2 | 2048 | 13645 | 19613.37 |
| retnet | 2 | 2048 | 13257.27 | 19379.32 |
| linear-attn | 2 | 2048 | 12667.81 | 17915.54 |
| delta-net | 2 | 2048 | 12222.73 | 17606.42 |
| rwkv6 | 2 | 2048 | 10511.13 | 16026.85 |
| mamba | 2 | 2048 | 10533.79 | 14058.06 |
| transformer | 1 | 2048 | 13270.76 | 20094.04 |
| hgrn | 1 | 2048 | 12584.37 | 18807.96 |
| samba | 1 | 2048 | 13280.36 | 18106.8 |
| retnet | 1 | 2048 | 10484.34 | 14925.46 |
| linear-attn | 1 | 2048 | 9974.69 | 14619.53 |
| gsa | 1 | 2048 | 10429.46 | 13602.88 |
| mamba | 1 | 2048 | 9746.15 | 12487 |
| hgrn2 | 1 | 2048 | 10932.83 | 12090.52 |
| delta-net | 1 | 2048 | 9735.89 | 11350.95 |
| rwkv6 | 1 | 2048 | 8658.36 | 7733.45 |
| transformer | 32 | 1024 | | 36088.91 |
| samba | 32 | 1024 | | 29441.18 |
| retnet | 32 | 1024 | | 23971.72 |
| delta-net | 32 | 1024 | | 23042.36 |
| mamba | 32 | 1024 | | 16431.25 |
| transformer | 16 | 1024 | 22978.64 | 34667.33 |
| hgrn | 16 | 1024 | | 31794.96 |
| samba | 16 | 1024 | 19938.37 | 28706.31 |
| gsa | 16 | 1024 | | 28277.14 |
| hgrn2 | 16 | 1024 | | 25243.74 |
| retnet | 16 | 1024 | 16574.38 | 23523.83 |
| linear-attn | 16 | 1024 | | 22535.4 |
| delta-net | 16 | 1024 | | 22520.52 |
| rwkv6 | 16 | 1024 | | 19232 |
| mamba | 16 | 1024 | | 16063.34 |
| transformer | 8 | 1024 | 21132.73 | 31095.91 |
| hgrn | 8 | 1024 | 19370.2 | 28536.46 |
| samba | 8 | 1024 | 18429.84 | 26916.36 |
| gsa | 8 | 1024 | 17816.01 | 25694.83 |
| hgrn2 | 8 | 1024 | 16036.75 | 23084.95 |
| retnet | 8 | 1024 | 15528.16 | 22135.69 |
| linear-attn | 8 | 1024 | 14763.62 | 20849.1 |
| delta-net | 8 | 1024 | 14605.84 | 20556.72 |
| rwkv6 | 8 | 1024 | | 17765.78 |
| mamba | 8 | 1024 | 11251.76 | 15380.6 |
| transformer | 4 | 1024 | 17611.25 | 26917.37 |
| hgrn | 4 | 1024 | 16157.12 | 23891.95 |
| samba | 4 | 1024 | 16246.94 | 23271.67 |
| hgrn2 | 4 | 1024 | 13711.21 | 20287.58 |
| gsa | 4 | 1024 | 14901.37 | 19226.52 |
| retnet | 4 | 1024 | 13333.89 | 19125.69 |
| linear-attn | 4 | 1024 | 12735.54 | 18367.85 |
| delta-net | 4 | 1024 | 12592.95 | 18226.03 |
| rwkv6 | 4 | 1024 | 10564.47 | 16148.39 |
| mamba | 4 | 1024 | 10229.23 | 14218.26 |
| samba | 2 | 1024 | 13023.25 | 17978.55 |
| hgrn2 | 2 | 1024 | 11099.64 | 16062.79 |
| hgrn | 2 | 1024 | 12677.8 | 15324.84 |
| delta-net | 2 | 1024 | 9957.74 | 14141.77 |
| linear-attn | 2 | 1024 | 10246.18 | 13032.14 |
| retnet | 2 | 1024 | 10804.88 | 13003.05 |
| rwkv6 | 2 | 1024 | 8714.04 | 12569.48 |
| mamba | 2 | 1024 | 8589.29 | 11972.97 |
| gsa | 2 | 1024 | 10758.39 | 10844.75 |
| samba | 1 | 1024 | 9545.75 | 9991.64 |
| retnet | 1 | 1024 | 6998.4 | 9003.62 |
| hgrn | 1 | 1024 | 7720.77 | 7824.65 |
| linear-attn | 1 | 1024 | 7059.6 | 7195.38 |
| hgrn2 | 1 | 1024 | 6435.73 | 6144.78 |
| mamba | 1 | 1024 | 6892.87 | 6058.78 |
| delta-net | 1 | 1024 | 5650.93 | 5780.75 |
| rwkv6 | 1 | 1024 | 4814.52 | 4774.23 |
| gsa | 1 | 1024 | 5319.96 | 4432.29 |
| transformer | 32 | 512 | 23491.27 | 35695.48 |
| hgrn | 32 | 512 | | 31913.77 |
| samba | 32 | 512 | 20188.05 | 28949.67 |
| gsa | 32 | 512 | | 28433.56 |
| hgrn2 | 32 | 512 | | 25231.68 |
| retnet | 32 | 512 | 16544.75 | 23569.82 |
| delta-net | 32 | 512 | | 22614.96 |
| linear-attn | 32 | 512 | | 22573.86 |
| rwkv6 | 32 | 512 | | 19212.5 |
| mamba | 32 | 512 | | 16269.5 |
| transformer | 16 | 512 | 21583.01 | 31914.11 |
| hgrn | 16 | 512 | 19453.76 | 28698.38 |
| samba | 16 | 512 | 18714.74 | 27012.99 |
| gsa | 16 | 512 | 17800.15 | 25769.91 |
| hgrn2 | 16 | 512 | 16091.78 | 23190.33 |
| retnet | 16 | 512 | 15507.6 | 22123.85 |
| linear-attn | 16 | 512 | 14759.27 | 20859.9 |
| delta-net | 16 | 512 | 14658.35 | 20794.48 |
| rwkv6 | 16 | 512 | | 17757.09 |
| mamba | 16 | 512 | | 15530.05 |
| transformer | 8 | 512 | 17959.34 | 26990.47 |
| hgrn | 8 | 512 | 16262.18 | 23961.92 |
| samba | 8 | 512 | 16445.8 | 23652.83 |
| hgrn2 | 8 | 512 | 13781.91 | 20332.45 |
| retnet | 8 | 512 | 13427.81 | 19785.93 |
| linear-attn | 8 | 512 | 12819.35 | 18371.43 |
| gsa | 8 | 512 | 15056.07 | 17805.26 |
| delta-net | 8 | 512 | 12641.08 | 17178.97 |
| rwkv6 | 8 | 512 | 10603.23 | 15986.87 |
| mamba | 8 | 512 | 10396.3 | 14451.54 |
| samba | 4 | 512 | 13159.76 | 18081.48 |
| transformer | 4 | 512 | 13941.52 | 17548.94 |
| retnet | 4 | 512 | 10853.05 | 15515.85 |
| hgrn | 4 | 512 | 12714.64 | 15507.94 |
| linear-attn | 4 | 512 | 10287.85 | 15261.04 |
| delta-net | 4 | 512 | 10165.85 | 14407.68 |
| hgrn2 | 4 | 512 | 11159.17 | 12823.84 |
| rwkv6 | 4 | 512 | 8760.7 | 12710.08 |
| mamba | 4 | 512 | 8676.28 | 12335.67 |
| gsa | 4 | 512 | 10788.66 | 10845.14 |
| hgrn | 2 | 512 | 7804.57 | 9850.05 |
| transformer | 2 | 512 | 8624.15 | 8912.99 |
| linear-attn | 2 | 512 | 6974.66 | 7168.91 |
| retnet | 2 | 512 | 6943.15 | 7048.41 |
| mamba | 2 | 512 | 6607.24 | 6737.07 |
| hgrn2 | 2 | 512 | 6440.37 | 5409.62 |
| delta-net | 2 | 512 | 5715.89 | 5383.6 |
| gsa | 2 | 512 | 5339.63 | 5336.48 |
| rwkv6 | 2 | 512 | 4889.76 | 4554 |
| samba | 1 | 512 | 4964.62 | 5028.2 |
| hgrn2 | 1 | 512 | 3215.79 | 4092.26 |
| transformer | 1 | 512 | 4371.82 | 3979.84 |
| hgrn | 1 | 512 | 3903.4 | 3925.22 |
| linear-attn | 1 | 512 | 3556.24 | 3670.65 |
| retnet | 1 | 512 | 3531.6 | 3574.41 |
| mamba | 1 | 512 | 3484.71 | 3459.5 |
| delta-net | 1 | 512 | 2884.54 | 2886.16 |
| gsa | 1 | 512 | 2666.12 | 2700.98 |
| rwkv6 | 1 | 512 | 2428.59 | 1955.17 |
@rakkit Thanks for your detailed benchmarks.
> There is a minor issue in the benchmark code: `seq_len` is not passed to the config, so `config.max_position_embeddings` stays at its default. Models such as the transformer then fail in long-sequence benchmarks (seq_len > max_position_embeddings).
We will fix it soon :>
Hi, thanks for your great work. I ran benchmarks to test all models' throughput and memory usage with `flash-linear-attention/benchmarks/benchmark_training_throughput.py`. Some of them, unfortunately, failed.

WORK:

FAILED:

(For Mamba, I noticed that FLA does not ask for `mamba_ssm` and `causal_conv1d`, but it also does not raise any warning that it is running the slow forward path; see the sketch below.)

Here are the benchmark results; the error info of the failed runs is attached at the end.
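On the silent Mamba fallback, this is roughly the kind of guard I would expect (a hypothetical sketch, not FLA's actual code):

```python
import warnings

try:
    from causal_conv1d import causal_conv1d_fn  # optional fused CUDA kernel
except ImportError:
    causal_conv1d_fn = None
    warnings.warn(
        "causal_conv1d is not installed; falling back to a slow PyTorch "
        "conv1d path. Install causal-conv1d (and mamba-ssm) for full speed."
    )
```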
Environment:
- Delta-Net
- HGRN2
- Linear_attention
- Rwkv6
- OOM for BS=8