PKU-YuanGroup / LanguageBind

【ICLR 2024🔥】 Extending Video-Language Pretraining to N-modality by Language-based Semantic Alignment
https://arxiv.org/abs/2310.01852
MIT License

Add flash attention 2 #19

Closed: pphuc25 closed this issue 5 months ago.

pphuc25 commented 5 months ago

While exploring the code, I noticed (please correct me if I'm wrong) that the current code does not use Flash Attention in training, only vanilla attention. I think Flash Attention is low-hanging fruit: training and evaluation would be faster while producing the same results. Do you have any plans to add Flash Attention to your code?
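The "faster but still the same result" point holds because Flash Attention's tiling is exact: the softmax can be computed block by block with a running maximum and running sum (often called "online softmax"), so full attention scores never have to be stored at once. Below is a minimal pure-Python sketch of that idea for a single query. The function names and toy inputs are hypothetical, and this is not the flash-attn kernel or its API. Results match vanilla attention up to floating-point rounding.

```python
import math


def naive_attention(q, keys, values):
    """Vanilla attention for one query: softmax(q.k / sqrt(d)) weighted sum of V."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim_v = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim_v)]


def tiled_attention(q, keys, values, block=2):
    """Same result, but keys/values are visited one block at a time while
    keeping only a running max, a running sum and a running output."""
    d = len(q)
    dim_v = len(values[0])
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * dim_v
    for start in range(0, len(keys), block):
        k_blk = keys[start:start + block]
        v_blk = values[start:start + block]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated so far to the new maximum
        # (exp(-inf) == 0.0 on the first block, so nothing breaks).
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]
```

Any block size gives the same output as `naive_attention`; real kernels pick the block size so that each tile fits in fast on-chip memory.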

LinB203 commented 5 months ago

Great idea! It is worth trying. Maybe we will release it with the HUGE version.

LinB203 commented 5 months ago

Initially, I intended to use Xformers to accelerate training because Xformers supports more GPUs than Flash Attention while offering similar efficiency. However, Xformers requires tensor shapes that are multiples of 8, and CLIP's text length of 77 tokens does not meet this requirement. Therefore, I will try Flash Attention instead.
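To make the multiple-of-8 constraint concrete: 77 rounds up to 80, so one workaround would be to right-pad each text sequence by 3 tokens and mask those positions out. The sketch below is a hypothetical helper, assuming a pad id of 0; it is not what this thread ended up doing.

```python
def pad_to_multiple(token_ids, multiple=8, pad_id=0):
    """Right-pad a token sequence so its length is a multiple of `multiple`.

    Returns the padded ids and a mask where 1 marks a real token and 0 marks
    padding. `pad_id=0` is an assumption; real tokenizers define their own.
    """
    target = -(-len(token_ids) // multiple) * multiple  # ceiling division
    n_pad = target - len(token_ids)
    padded = list(token_ids) + [pad_id] * n_pad
    mask = [1] * len(token_ids) + [0] * n_pad
    return padded, mask


# A 77-token CLIP-style sequence becomes 80 tokens, 3 of them padding.
ids, mask = pad_to_multiple(list(range(1, 78)))
```

The catch is that the padded positions must then be excluded through an attention mask, which is where the mask-support limitation discussed later in this thread comes back in.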

LinB203 commented 5 months ago

According to this issue, Flash Attention currently does not support custom attention masks. However, OpenCLIP's text requires irregular attention masks. Additionally, the OpenCLIP team seems to be exploring the integration of Flash Attention here. Perhaps we can wait for a while to see if any developments arise.
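A "custom" attention mask here means an arbitrary additive bias on the attention scores (0 to allow, -inf to block) rather than a fixed pattern the kernel can generate on its own. The hypothetical sketch below assumes a causal rule combined with a per-sample padding rule, which is one way a mask ends up differing from sample to sample; it does not reproduce OpenCLIP's actual mask construction.

```python
import math

NEG_INF = float("-inf")


def build_mask(seq_len, valid_len):
    """Additive mask for one sample: 0.0 where query i may attend to key j,
    -inf otherwise. Combines a causal rule (j <= i) with a padding rule
    (j < valid_len), so the mask depends on each sample's length."""
    return [
        [0.0 if (j <= i and j < valid_len) else NEG_INF for j in range(seq_len)]
        for i in range(seq_len)
    ]


def masked_softmax(scores, mask_row):
    """Softmax over one row of scores after adding the mask row."""
    biased = [s + m for s, m in zip(scores, mask_row)]
    top = max(biased)
    exps = [math.exp(b - top) for b in biased]  # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


# Two samples of the same padded length but different real lengths.
mask_short = build_mask(6, valid_len=4)
mask_full = build_mask(6, valid_len=6)
```

Because `mask_short` and `mask_full` differ, the mask has to be passed in as data per sample instead of being expressed as a single flag.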

pphuc25 commented 5 months ago

Cool information, thank you for providing it.

pphuc25 commented 5 months ago

However, do you think extending the context length would be worth the effort? Multiples of 8 have many advantages, such as being better suited to TPUs (mod 2 == 0). @LinB203

LinB203 commented 5 months ago

In fact, the token length of 77 is not something I set; it is a setting of the CLIP model. I have not changed it because I do not want to modify any pretrained weights, and I am not sure whether changing it to a multiple of 8 would affect performance. Additionally, to reduce training costs, I have tried incorporating gradient checkpointing (grad_checkpoint), which reduced memory consumption from 31 GB to 11 GB on a V100 with the same batch size. The code will be released along with the HUGE version.
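Gradient checkpointing trades compute for memory: only some intermediate activations are kept during the forward pass, and the others are recomputed during the backward pass. Below is a minimal pure-Python sketch on a hypothetical chain of 16 scalar `tanh` layers. It is not LanguageBind's implementation (in PyTorch the same idea is provided by `torch.utils.checkpoint`), and it does not reproduce the 31 GB to 11 GB figure above.

```python
import math

# Hypothetical toy network: 16 scalar layers, each computing tanh(w * x).
WEIGHTS = [1.1 if i % 2 == 0 else -0.9 for i in range(16)]


def layer(w, x):
    return math.tanh(w * x)


def layer_grad(w, x):
    # d/dx tanh(w * x) = w * (1 - tanh(w * x) ** 2)
    t = math.tanh(w * x)
    return w * (1.0 - t * t)


def grad_full(x):
    """Plain backprop: keep the input of every layer for the backward pass."""
    inputs = []
    for w in WEIGHTS:
        inputs.append(x)
        x = layer(w, x)
    grad = 1.0
    for w, inp in zip(reversed(WEIGHTS), reversed(inputs)):
        grad *= layer_grad(w, inp)
    return x, grad, len(inputs)  # output, d(output)/d(input), values stored


def grad_checkpointed(x, segment=4):
    """Keep only each segment's input; recompute the rest during backward."""
    checkpoints = []
    for start in range(0, len(WEIGHTS), segment):
        checkpoints.append(x)
        for w in WEIGHTS[start:start + segment]:
            x = layer(w, x)
    grad = 1.0
    peak = len(checkpoints)
    for idx in reversed(range(len(checkpoints))):
        seg_weights = WEIGHTS[idx * segment:(idx + 1) * segment]
        # Re-run the forward pass of this one segment from its checkpoint.
        inputs, h = [], checkpoints[idx]
        for w in seg_weights:
            inputs.append(h)
            h = layer(w, h)
        # Peak = all checkpoints plus one segment's recomputed inputs.
        peak = max(peak, len(checkpoints) + len(inputs))
        for w, inp in zip(reversed(seg_weights), reversed(inputs)):
            grad *= layer_grad(w, inp)
    return x, grad, peak
```

With 16 layers and segments of 4, `grad_full` stores 16 layer inputs while `grad_checkpointed` peaks at 8 (4 checkpoints plus one recomputed segment). Each segment's forward pass runs twice, and the gradient is unchanged because the recomputation is deterministic.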


LinB203 commented 5 months ago

Larger models can now be trained by turning on --grad-checkpointing.
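How a command-line switch such as `--grad-checkpointing` typically reaches the model can be sketched with `argparse`. `ToyModel`, `set_grad_checkpointing`, and `configure` below are hypothetical stand-ins; check the repository's training script for how the real flag is consumed.

```python
import argparse


class ToyModel:
    """Hypothetical stand-in for a model with a checkpointing switch."""

    def __init__(self):
        self.grad_checkpointing = False

    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable


def build_parser():
    parser = argparse.ArgumentParser()
    # store_true: the value is False unless the flag appears on the command line.
    parser.add_argument("--grad-checkpointing", action="store_true")
    return parser


def configure(argv):
    args = build_parser().parse_args(argv)
    model = ToyModel()
    # argparse converts "--grad-checkpointing" into the attribute grad_checkpointing.
    if args.grad_checkpointing:
        model.set_grad_checkpointing(True)
    return model
```

Keeping the switch off by default means existing training commands behave as before; only runs that need the memory savings pay the recomputation cost.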