Closed: woolpeeker closed this issue 2 years ago
It seems like you are using fp32; could you try fp16?
Sure, I tested it with fp16.
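For reference, the fp16 switch only amounts to casting the block and the test input to half precision before timing; a minimal illustrative sketch (the sizes here are assumptions, not the ones from the actual test script):

```python
import torch
from timm.models.vision_transformer import Block

# Illustrative setup only: dim, heads, batch and sequence length are assumptions.
block = Block(dim=768, num_heads=12, qkv_bias=True).cuda().half()
x = torch.randn(64, 197, 768, device="cuda", dtype=torch.float16)

with torch.no_grad():
    out = block(x)  # forward pass entirely in fp16
```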
with backward=False:
timm finished 500 running, avg_time: 10.408471 ms
light_seq finished 500 running, avg_time: 9.462291 ms

with backward=True:
timm finished 500 running, avg_time: 31.718561 ms
light_seq finished 500 running, avg_time: 30.036270 ms

Only about a 1.7 ms difference.
The test environment is:
- PyTorch version: 1.12.1
- CUDA used to build PyTorch: 11.3
- Python version: 3.9.13 (main, Aug 25 2022, 23:26:10) [GCC 11.2.0] (64-bit runtime)
- lightseq==2.2.1
Do you have official test results for ViT comparing native PyTorch and lightseq?
We have tested on 8xA100, and this is the result.
BTW, could you tell me which card you are using? I will check if I can reproduce your results.
I used a single A100-80G GPU.
Hi, I tested your script with fp16 precision. Below is the result:
I ran it multiple times, and the results are the same.
I also ran other dims and batch_sizes, and all the results show that lightseq is faster.
Thanks for testing. The lightseq result is the same as mine, around 9.5 ms on our machine. The original timm layer is a little faster here than in your results; ours is around 10.4 ms. Lightseq is always faster than the original layer, but the gap is smaller on our machine.
I will close this issue.
BTW, do you have plans to implement FlashAttention in lightseq? I heard it is much faster than previous methods.
Hi @woolpeeker, I'm also trying to integrate lightseq into my project with a timm ViT model. After switching to the lightseq layer, can we still load the same pretrained model weights?
Yes, you just need to organize the weight tensors following lightseq's docs. I have tested it; the results align with timm's.
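Roughly, the idea is to pull the pretrained tensors out of the timm block and hand them to the lightseq layer at construction time. The sketch below assumes the `LSTransformerEncoderLayer(config, initial_weights, initial_biases)` constructor shown in lightseq's examples; the exact ordering of the two lists (and details such as the activation function) is an assumption here, so take the real layout from lightseq's docs.

```python
import torch
from timm.models.vision_transformer import Block
from lightseq.training import LSTransformerEncoderLayer

dim, heads = 768, 12  # illustrative sizes, not tied to any particular checkpoint
blk = Block(dim=dim, num_heads=heads, qkv_bias=True)

# timm stores Q, K, V as one fused [3*dim, dim] projection; split it back out.
qw, kw, vw = blk.attn.qkv.weight.detach().chunk(3, dim=0)
qb, kb, vb = blk.attn.qkv.bias.detach().chunk(3, dim=0)

# NOTE: the order of these lists is an assumption for illustration only;
# use the order given in lightseq's docs/examples when doing this for real.
initial_weights = [
    qw, kw, vw,
    blk.attn.proj.weight,
    blk.norm1.weight,
    blk.mlp.fc1.weight,
    blk.mlp.fc2.weight,
    blk.norm2.weight,
]
initial_biases = [
    qb, kb, vb,
    blk.attn.proj.bias,
    blk.norm1.bias,
    blk.mlp.fc1.bias,
    blk.mlp.fc2.bias,
    blk.norm2.bias,
]

# Config values follow the README-style example; ViT uses GELU and no padding
# mask, which also needs to be matched according to the docs.
config = LSTransformerEncoderLayer.get_config(
    max_batch_tokens=4096,
    max_seq_len=512,
    hidden_size=dim,
    intermediate_size=4 * dim,
    nhead=heads,
    attn_prob_dropout_ratio=0.0,
    activation_dropout_ratio=0.0,
    hidden_dropout_ratio=0.0,
    pre_layer_norm=True,
    fp16=True,
    local_rank=0,
)
ls_layer = LSTransformerEncoderLayer(config, initial_weights, initial_biases)
```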
@woolpeeker Thanks! This saves me a lot of time.
Hi @woolpeeker, I followed the doc and integrated the lightseq transformer layer into timm ViT, but the speed improvement is trivial. I'm wondering whether you can observe any speedup, as I'm using the same GPU (A100) and the snippet above gives me exactly the same results as yours. I guess the relative improvement depends heavily on the batch size or other hyperparameters.
BTW, switching to FlashAttention gives me a speed boost of around 10% for the whole timm ViT model, and for a single attention block the speedup is 45%. I hope this helps.
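One way to try a FlashAttention-style kernel is PyTorch 2.x's `F.scaled_dot_product_attention`, which can dispatch to a fused FlashAttention implementation on supported GPUs. This is only an illustrative replacement for timm's attention module, not the exact FlashAttention integration referred to above:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SDPAttention(nn.Module):
    """Attention built on F.scaled_dot_product_attention (PyTorch >= 2.0),
    which can dispatch to a FlashAttention kernel on supported hardware."""

    def __init__(self, dim: int, num_heads: int, qkv_bias: bool = True):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        # (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        out = F.scaled_dot_product_attention(q, k, v)  # fused attention kernel
        out = out.transpose(1, 2).reshape(B, N, C)
        return self.proj(out)
```

Swapping this in for a timm block's `attn` module (after copying the pretrained `qkv` and `proj` weights over) is the kind of change being described.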
I use the code below to test the ViT block speed. The output shows that the speed is almost the same between PyTorch and lightseq.
Did I miss something?
Output for forward only:
The output for forward + backward:
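A minimal sketch of this kind of timing comparison (the sizes and the timing loop here are illustrative assumptions, not the exact script used above; the lightseq layer would be built as described earlier in the thread and benchmarked the same way):

```python
import time
import torch
from timm.models.vision_transformer import Block

def benchmark(layer, x, steps=500, backward=False, name="layer"):
    """Time `steps` forward (and optionally backward) passes of `layer` on `x`."""
    for _ in range(10):  # warm-up
        out = layer(x)
        if backward:
            out.sum().backward()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(steps):
        out = layer(x)
        if backward:
            out.sum().backward()
    torch.cuda.synchronize()
    avg_ms = (time.time() - start) / steps * 1000
    print(f"{name} finished {steps} running, avg_time: {avg_ms:.6f} ms")

if __name__ == "__main__":
    dim, heads, batch, seq = 768, 12, 64, 197  # illustrative sizes, not from the issue
    x = torch.randn(batch, seq, dim, device="cuda",
                    dtype=torch.float16, requires_grad=True)
    timm_block = Block(dim=dim, num_heads=heads, qkv_bias=True).cuda().half()
    benchmark(timm_block, x, backward=True, name="timm")
    # Build the lightseq layer (e.g. via LSTransformerEncoderLayer.get_config(...))
    # and pass it through benchmark() with the same input for a direct comparison.
```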