Closed pphuc25 closed 5 months ago
Great idea! It is worth trying. Maybe we will release it with the HUGE version.
Initially, I intended to use Xformers to accelerate training because Xformers supports more GPUs than Flash Attention while maintaining similar efficiency. However, Xformers requires tensor shapes that are multiples of 8. Since the text length in CLIP is 77, it does not meet this requirement, so I will try Flash Attention instead.
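One possible way around the multiple-of-8 constraint, rather than abandoning Xformers, would be padding the 77-token sequence up to 80 and masking out the padded positions in attention. A minimal sketch under that assumption (the helper `pad_to_multiple` is hypothetical, not from this repo):

```python
import torch
import torch.nn.functional as F

def pad_to_multiple(x: torch.Tensor, multiple: int = 8, pad_value: float = 0.0):
    """Pad the sequence dimension of a (batch, seq, dim) tensor up to a multiple.

    Returns the padded tensor and the original sequence length, so the
    caller can mask out (or slice off) the padded positions afterwards.
    """
    seq_len = x.shape[1]
    remainder = seq_len % multiple
    if remainder == 0:
        return x, seq_len
    pad = multiple - remainder
    # F.pad works from the last dim backwards: (dim_l, dim_r, seq_l, seq_r)
    x = F.pad(x, (0, 0, 0, pad), value=pad_value)
    return x, seq_len

# CLIP's text length is 77; padding brings it to the next multiple of 8.
tokens = torch.randn(2, 77, 512)
padded, orig_len = pad_to_multiple(tokens)
print(padded.shape)  # torch.Size([2, 80, 512])
```

The padded positions must of course be excluded via the attention mask so they cannot change the result for the real 77 tokens.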
According to this issue, Flash Attention currently does not support custom attention masks. However, OpenCLIP's text requires irregular attention masks. Additionally, the OpenCLIP team seems to be exploring the integration of Flash Attention here. Perhaps we can wait for a while to see if any developments arise.
Cool information, thank you for providing it!
However, do you think the effort of extending the context length is worth it? A length that is a multiple of 8 has many advantages, such as better suitability for TPUs (and mod 2 == 0). @LinB203
In fact, the token length of 77 is not something I set; it is a setting of the CLIP model. I have not changed it because I do not want to modify any pretrained weights, and I am not sure whether changing it to a multiple of 8 would affect performance. Additionally, to reduce training costs, I have tried incorporating grad_checkpoint, which reduced memory consumption from 31 GB to 11 GB on a V100 at the same batch size. The code will be released along with the HUGE version.
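For context, gradient checkpointing trades compute for memory by discarding intermediate activations in the forward pass and recomputing them during backward. A minimal sketch with `torch.utils.checkpoint` (the block stack here is a stand-in for the actual CLIP transformer resblocks, not this repo's code):

```python
import torch
from torch.utils.checkpoint import checkpoint

# Stand-in transformer blocks; real code would wrap each resblock of the
# CLIP text/vision towers instead.
blocks = torch.nn.ModuleList(
    torch.nn.Sequential(
        torch.nn.Linear(512, 2048),
        torch.nn.GELU(),
        torch.nn.Linear(2048, 512),
    )
    for _ in range(4)
)

def forward(x: torch.Tensor, use_checkpoint: bool = True) -> torch.Tensor:
    for block in blocks:
        if use_checkpoint and x.requires_grad:
            # Activations inside `block` are recomputed in backward,
            # so they do not need to be kept in memory.
            x = checkpoint(block, x, use_reentrant=False)
        else:
            x = block(x)
    return x

x = torch.randn(2, 77, 512, requires_grad=True)
out = forward(x)
out.sum().backward()  # gradients flow through the checkpointed blocks
```

The memory saving grows with depth and batch size, at the cost of roughly one extra forward pass during backward.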
Larger models can now be trained by turning on --grad-checkpointing.
From exploring the code, and to my knowledge (please correct me if I am wrong), the current code does not use Flash Attention in training, but rather vanilla attention. I think Flash Attention is low-hanging fruit: training and evaluation would be faster while producing the same results. Do you have any plan to apply Flash Attention to your code?