sustcsonglin / flash-linear-attention

Efficient implementations of state-of-the-art linear attention models in PyTorch and Triton
MIT License
784 stars, 45 forks

training efficiency of GLA #30

Closed pengzhangzhi closed 2 weeks ago

pengzhangzhi commented 3 weeks ago

Hi, great work! I am comparing GLA with PyTorch's attention in diffusion model training. GLA seems considerably slower than attention, and the gap is even worse in the first epoch. I would love to know why, and whether there is anything we can do about it. [image]

yzhangcs commented 3 weeks ago

@pengzhangzhi Hello, may I know your training settings? Did you train GLA with bf16/fp16? What is the exact training throughput (tokens/GPU/s) of each?
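For reference, throughput in tokens/GPU/s is usually taken as the number of tokens processed divided by wall-clock time and GPU count. A minimal sketch follows; the function name and arguments are illustrative and are not part of fla, and the step count and elapsed time are made-up example values.

```python
def tokens_per_gpu_per_second(batch_size, seq_len, num_steps,
                              elapsed_seconds, num_gpus=1):
    """Tokens processed per GPU per second over a timed window.

    batch_size is assumed to be the global batch size per step
    (summed over all GPUs). All names are illustrative, not an fla API.
    """
    if elapsed_seconds <= 0 or num_gpus <= 0:
        raise ValueError("elapsed_seconds and num_gpus must be positive")
    total_tokens = batch_size * seq_len * num_steps
    return total_tokens / elapsed_seconds / num_gpus


# Example values only: batch 32, length 128, 100 steps timed at 10 seconds.
throughput = tokens_per_gpu_per_second(32, 128, num_steps=100,
                                       elapsed_seconds=10.0)
print(f"{throughput:.1f} tokens/GPU/s")
```

Timing a fixed number of steps after the first few iterations, rather than a whole epoch, keeps one-time startup costs out of the number.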

yzhangcs commented 3 weeks ago

FYI, below is the training throughput measured on a single H100:

[image]

pengzhangzhi commented 3 weeks ago

Thanks for the context! I don't have the exact throughput, but the settings are: sequence length 128, batch size 32, and a model of ~22M parameters. I replaced the transformer block in the baseline with the GLA block, which adds about 25% parameter overhead. I train in full precision.

yzhangcs commented 3 weeks ago

@pengzhangzhi I see. I believe the model is too small, so the additional FLOPs are not negligible the way they are for models >340M. Also, fla may not show its advantage at an input sequence length of 128. As you can see from the figure, the overall throughput of fla does not exceed that of attention at small lengths.

pengzhangzhi commented 3 weeks ago

Yes. The advantage I observed is that, during training, the FLA-based model consumes less GPU memory, which allows a larger batch size on the same hardware. But the catch is that inference with the FLA model is quite slow, which might hurt downstream applications through degraded efficiency, if I understand correctly.

yzhangcs commented 3 weeks ago

> But the catch is that inference with the FLA model is quite slow, which might hurt downstream applications through degraded efficiency.

Could you give me a precise comparison, e.g., prefix length and model size? Runnable code would be even better.

pengzhangzhi commented 3 weeks ago

Yes, I would like to set up the comparison as follows:

bs: 32
seq_len: 128, 512, 1024
model_size: ~20M, ~40M (by adjusting the number of blocks, e.g., 10 and 20)

The goal is to compare the inference speed of the GLA-based and attention-based models under the above configs. I don't have results yet, but in my experience GLA weirdly takes a while. Maybe it is bootstrapping something at the beginning every time you call it? This aligns with training, where I sense that the first epoch takes much longer than the subsequent epochs.

yzhangcs commented 3 weeks ago

@pengzhangzhi

> Maybe it is bootstrapping something at the beginning every time you call it? This aligns with training, where I sense that the first epoch takes much longer than the subsequent epochs.

Not sure, but it may be because Triton needs a few warmup runs first to autotune the kernel arguments for better speed.

What about the hidden size? Maybe you can compare single attention layers only. Does the time you measured cover prefilling only, or does it include one-by-one decoding steps?

pengzhangzhi commented 3 weeks ago

FYI, I am using it as a BERT-style model, so it is a forward pass only.

yzhangcs commented 3 weeks ago

@pengzhangzhi Can you try this script, https://github.com/sustcsonglin/flash-linear-attention/blob/main/benchmarks/ops/benchmark_fla.py, after adjusting the batch size, hidden size, and sequence length to the values you want?

sustcsonglin commented 3 weeks ago

> Yes, I would like to set up the comparison as follows:
>
> bs: 32
> seq_len: 128, 512, 1024
> model_size: ~20M, ~40M (by adjusting the number of blocks, e.g., 10 and 20)
>
> The goal is to compare the inference speed of the GLA-based and attention-based models under the above configs. I don't have results yet, but in my experience GLA weirdly takes a while. Maybe it is bootstrapping something at the beginning every time you call it? This aligns with training, where I sense that the first epoch takes much longer than the subsequent epochs.

The sequence length is too small, so in this case the I/O cost dominates. FlashAttention reads the sequence only once, while the chunkwise kernel for GLA needs to read the sequence multiple times. It looks normal to me that the GLA kernels are about twice as slow as FlashAttention, given the higher I/O cost. I would expect the fused_chunk mode to be faster in this setting.
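To make the I/O argument concrete, here is a back-of-the-envelope model: when memory traffic dominates, kernel time scales roughly with the bytes read, so reading the inputs twice costs about twice as much. Everything in this sketch is an assumption for illustration: the number of passes, the hidden size, the three input tensors, and the bandwidth figure are placeholders, not measurements of the fla or FlashAttention kernels.

```python
def io_bound_time_ms(batch, seq_len, hidden, passes,
                     bytes_per_elem=4, num_tensors=3,
                     bandwidth_gb_s=1000.0):
    """Rough lower bound on kernel time when memory reads dominate.

    passes: how many times the inputs are read from GPU memory (assumed).
    num_tensors: number of input tensors of shape (batch, seq_len, hidden).
    bandwidth_gb_s: placeholder memory bandwidth, not a measured value.
    """
    bytes_moved = (passes * num_tensors * batch * seq_len
                   * hidden * bytes_per_elem)
    return bytes_moved / (bandwidth_gb_s * 1e9) * 1e3


# Batch 32 and length 128 as in this thread; hidden size 512 is hypothetical.
single_pass = io_bound_time_ms(32, 128, 512, passes=1)  # read inputs once
two_passes = io_bound_time_ms(32, 128, 512, passes=2)   # read inputs twice
print(f"ratio: {two_passes / single_pass:.1f}x")
```

The model ignores compute, writes, and caching, so it only shows why extra reads translate almost directly into extra time at short sequence lengths.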