Closed pengzhangzhi closed 2 weeks ago
@pengzhangzhi Hello, may I know your training settings: did you train GLA with bf16 or fp16? And what is the exact training throughput (toks/gpu/s) of both?
FYI, below is the training throughput tested on a single H100
Thanks for the context!! I don't have the exact throughput, but the settings are: sequence length 128, batch size 32, and a ~22M-parameter model. I replaced the transformer block in the baseline with the GLA block, which results in ~25% parameter overhead. I use full precision (fp32).
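As a back-of-the-envelope sketch of where parameter overhead like this can come from: GLA layers add gating projections on top of the usual q/k/v/o projections. The block structure and sizes below are hypothetical (the real `fla` layer and the model above may differ), so the exact overhead percentage will not match any particular setup.

```python
def attn_block_params(d):
    """Vanilla transformer block: q/k/v/o projections plus a 4x-wide MLP."""
    attn = 4 * d * d          # q, k, v, o projections
    mlp = 2 * (d * 4 * d)     # up- and down-projection
    return attn + mlp

def gla_block_params(d, gate_rank=16):
    """GLA block: attention projections plus an output gate and a
    low-rank decay-gate projection (assumed structure)."""
    out_gate = d * d                          # full output gate
    decay_gate = d * gate_rank + gate_rank * d  # low-rank decay gate
    return attn_block_params(d) + out_gate + decay_gate

d = 384  # hypothetical hidden size
overhead = gla_block_params(d) / attn_block_params(d) - 1
print(f"per-block overhead: {overhead:.1%}")
```

The point is only that the relative overhead shrinks as the MLP and attention projections grow with hidden size, which connects to the model-size discussion below.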
@pengzhangzhi I see. I believe the model is too small, so the additional FLOPs are not negligible compared to models >340M. Also, fla may not show its advantage at an input sequence length of 128. As you can see from the figure, the overall throughput of fla cannot outperform attn at small sequence lengths.
Yes, the advantage I observed is during training: the FLA-based model consumes less GPU memory, enabling a larger batch size on the same hardware. But the downside is that inference of the FLA model is quite slow, which might be disadvantageous for downstream applications. If I understand correctly.
> But the downside is that inference of the FLA model is quite slow, which might be disadvantageous for downstream applications.
Could you give me precise comparisons, e.g., prefix length, model size. It would be better to provide runnable code.
Yes, I would like to set up the comparison as follows: bs: 32; seq_len: 128, 512, 1024; model size: ~20M, ~40M (by adjusting the number of blocks, e.g., 10 and 20).
Comparing the inference speed of the GLA- and attention-based models in the above configs. I don't have the results yet, but my experience is that GLA weirdly takes more time. Maybe it's bootstrapping something the first time you call it? This aligns with training, where I noticed the first epoch takes much longer than the subsequent epochs.
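The proposed comparison above can be sketched as a small config grid (values taken from the message; everything else is just scaffolding to iterate over):

```python
from itertools import product

# Benchmark grid from the proposed comparison: batch size 32,
# three sequence lengths, and two depths (~20M and ~40M params).
batch_sizes = [32]
seq_lens = [128, 512, 1024]
num_blocks = [10, 20]

configs = list(product(batch_sizes, seq_lens, num_blocks))
for bs, seq_len, n_blocks in configs:
    print(f"bs={bs} seq_len={seq_len} blocks={n_blocks}")
```

Each (bs, seq_len, blocks) triple would then be timed for both the GLA and the attention model.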
@pengzhangzhi
> Maybe it's bootstrapping something the first time you call it? This aligns with training, where I noticed the first epoch takes much longer than the subsequent epochs.
Not sure, but it may be because Triton needs some warmup runs first to auto-tune the kernel arguments for better speed.
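This is why benchmarks usually discard the first few iterations. A generic timing sketch (pure-Python stand-in; swap the lambda for a real model forward pass) that absorbs one-time setup cost such as Triton autotuning on the first call:

```python
import time

def bench(fn, warmup=3, iters=10):
    """Average per-call time of fn, excluding warmup iterations."""
    for _ in range(warmup):   # absorb first-call compilation/autotune cost
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters

# Toy workload standing in for a forward pass.
avg = bench(lambda: sum(i * i for i in range(10_000)))
print(f"avg per-call time: {avg:.6f}s")
```

Without the warmup loop, the first (slow, autotuned) call would be averaged into the result, which matches the "first epoch is much slower" observation above.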
What about the hidden size? Maybe you can compare single attention layers only. Is the time you measured prefix-filling only, or does it include the 1-by-1 decoding steps?
FYI, I am using it as a BERT-style encoder: only the forward pass.
@pengzhangzhi Can you try this script: https://github.com/sustcsonglin/flash-linear-attention/blob/main/benchmarks/ops/benchmark_fla.py, slightly revising the batch, hidden, and seqlen to the values you want?
> Yes, I would like to set up the comparison as follows: bs: 32; seq_len: 128, 512, 1024; model size: ~20M, ~40M (by adjusting the number of blocks, e.g., 10 and 20).
> Comparing the inference speed of the GLA- and attention-based models in the above configs. I don't have the results yet, but my experience is that GLA weirdly takes more time. Maybe it's bootstrapping something the first time you call it? This aligns with training, where I noticed the first epoch takes much longer than the subsequent epochs.
The sequence length is too small, so in this case the I/O cost will dominate. FlashAttention reads the sequence only once, while the chunkwise kernel for GLA requires reading the sequence multiple times. It looks normal to me that the GLA kernel is twice as slow as FlashAttention due to the higher I/O cost. I would expect fused_chunk mode to be faster under this setting.
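A back-of-the-envelope sketch of the I/O argument above: FlashAttention streams the (short) sequence roughly once, while a chunkwise kernel re-reads it once per pass over the chunks. All numbers below are illustrative assumptions, not measurements of either kernel.

```python
# Hypothetical shapes matching the discussion: bf16 activations,
# batch 32, seq_len 128, and an assumed hidden size of 512.
bytes_per_el = 2
bs, seq_len, hidden = 32, 128, 512
qkv_bytes = 3 * bs * seq_len * hidden * bytes_per_el  # q, k, v tensors

flash_reads = 1 * qkv_bytes   # one streaming pass (roughly)
chunkwise_passes = 2          # assumed extra inter-/intra-chunk passes
chunk_reads = chunkwise_passes * qkv_bytes

print(f"memory-traffic ratio: {chunk_reads / flash_reads:.1f}x")
```

At short sequence lengths this memory traffic, not FLOPs, is the bottleneck, which is why the chunkwise GLA kernel can lag FlashAttention here even though linear attention wins asymptotically.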
Hi, great work!! I am comparing GLA and attention (from PyTorch) in diffusion model training. GLA seems quite a bit slower than attention, and even worse in the first epoch. Would love to know why and what solutions we can take.![image](https://github.com/sustcsonglin/flash-linear-attention/assets/59241275/fc2d0f7f-8634-4771-aeac-a54d9653fa7e)