UCSC-VLAA / CLIPA

[NeurIPS 2023] This repository includes the official implementation of our paper "An Inverse Scaling Law for CLIP Training"
Apache License 2.0

An Inverse Scaling Law for CLIP Training

This repo contains the official PyTorch and JAX implementations of CLIPA, as introduced in our paper: An Inverse Scaling Law for CLIP Training

Overview of the inverse scaling law: larger image/text encoders enable training with fewer image/text tokens while maintaining competitive performance.

📰 News

[2023.10.4] We have successfully scaled our model up to bigG/14, reaching 83.0% zero-shot top-1 accuracy on ImageNet-1K. For the detailed training configuration, please refer to the t5x branch. Pre-trained and fine-tuned weights for both the JAX and PyTorch versions are available on Google Drive.

[2023.9.21] Our paper is accepted by NeurIPS 2023!

[2023.6.16] We released CLIPA-v2. Compared to the prior best publicly available CLIP model, CLIPA-v2 trains significantly faster and performs better. Our best model, H/14@336x336 trained on DataComp-1B, reaches 81.8% zero-shot top-1 accuracy on ImageNet-1K at an estimated training cost of under $15k!

[Note] All of our CLIPA-v2 models were trained on TPUs using our JAX codebase. We followed the same pre-training process as CLIPA-v1, but with a more efficient fine-tuning strategy. To help replicate our results, we provide the training configurations (e.g., for the H/14 model, in this folder here), along with the pre-trained weights, configurations, and logs, which can be found here.

| Model | Data | Schedule | GPU Hours | Estimated Cost | Zero-shot IN-1K top-1 (%) | Model Weights |
|---|---|---|---|---|---|---|
| H/14 | LAION-2B | 12.8B@84 + 512M@224 + 128M@336 | 8640 | $13613 | 79.1 | PyTorch / JAX |
| L/14 | DataComp-1B | 12.8B@84 + 512M@224 + 128M@336 | 4520 | $7124 | 80.3 | PyTorch / JAX |
| H/14 | DataComp-1B | 12.8B@84 + 512M@224 + 128M@336 | 8640 | $13613 | 81.8 | PyTorch / JAX |
| bigG/14 | DataComp-1B | 12.8B@84 + 512M@224 + 128M@336 | 23742 | $39056 | 83.0 | PyTorch / JAX |

GPU hours for CLIPA-v2 are estimated on an 8×A100 (80GB) machine on Google Cloud, and the corresponding training costs are estimated from the cloud pricing of 80GB A100 GPUs.
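Assuming each estimated cost is simply GPU hours multiplied by a flat per-GPU-hour price, the price implied by each row can be back-calculated. The minimal Python sketch below does this with the numbers from the table; the helper names (`implied_rate`, `estimate_cost`) are illustrative, and the resulting rates (roughly $1.6 per A100-hour) are derived from the table, not quoted from a Google Cloud price list.

```python
# Values copied from the CLIPA-v2 table: (model, data, GPU hours, estimated cost in USD).
ROWS = [
    ("H/14", "LAION-2B", 8640, 13613),
    ("L/14", "DataComp-1B", 4520, 7124),
    ("H/14", "DataComp-1B", 8640, 13613),
    ("bigG/14", "DataComp-1B", 23742, 39056),
]


def implied_rate(gpu_hours: float, cost_usd: float) -> float:
    """USD per A100-hour implied by the assumption cost = gpu_hours * rate."""
    return cost_usd / gpu_hours


def estimate_cost(gpu_hours: float, rate_usd_per_gpu_hour: float) -> float:
    """Estimate a training cost as GPU hours times a per-GPU-hour price."""
    return gpu_hours * rate_usd_per_gpu_hour


for model, data, hours, cost in ROWS:
    print(f"{model:8s} {data:12s} ${implied_rate(hours, cost):.2f} per A100-hour")
```

The same `estimate_cost` helper can be reused with a different hourly price to re-estimate a run under other cloud pricing.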

Introduction

CLIP, the first foundation model that connects images and text, has enabled many recent breakthroughs in computer vision. However, its training cost is prohibitively high, which poses a significant barrier to its widespread exploration. In this paper, we present a surprising finding: there exists an inverse scaling law for CLIP training, whereby the larger the image/text encoders used, the shorter the image/text token sequences that can be used in training. Moreover, we show that the strategy used to reduce image/text token length plays a crucial role in determining the quality of this scaling law.
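To make "shorter token sequences" concrete on the image side: a ViT splits a square image into (resolution / patch size)² patch tokens, so lowering the input resolution shortens the sequence quadratically. The sketch below applies this to the CLIPA-v2 schedule listed in the table above. It assumes that an entry such as "12.8B@84" means 12.8B training samples seen at 84×84 resolution and that the "/14" in a model name is its patch size; it ignores the class token and the text side.

```python
def num_image_tokens(resolution: int, patch_size: int) -> int:
    """Patch tokens a ViT produces for a square image (class token not counted)."""
    if resolution % patch_size != 0:
        raise ValueError("resolution must be divisible by patch_size")
    side = resolution // patch_size
    return side * side


# Assumption: (samples seen, input resolution) pairs read from "12.8B@84 + 512M@224 + 128M@336".
SCHEDULE = [(12.8e9, 84), (512e6, 224), (128e6, 336)]
PATCH_SIZE = 14  # assumed from the "/14" in H/14, L/14, bigG/14

for samples, res in SCHEDULE:
    tokens = num_image_tokens(res, PATCH_SIZE)
    print(f"{samples / 1e9:6.3f}B samples @ {res}px -> {tokens} image tokens per sample")

# Sequence length of the low-resolution stage relative to standard 224px training.
ratio = num_image_tokens(84, PATCH_SIZE) / num_image_tokens(224, PATCH_SIZE)
print(f"84px sequence length is {ratio:.1%} of the 224px length")
```

Under these assumptions, most of the schedule runs at 36 image tokens per sample instead of 256, which is where much of the compute saving comes from.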

As a result of this finding, we can successfully train CLIP even with academic resources. For example, on a single eight-GPU A100 server, our CLIP models achieve zero-shot top-1 ImageNet accuracies of 63.2% in about 2 days, 67.8% in about 3 days, and 69.3% in about 4 days. By lowering the computational barrier associated with CLIP, we hope to inspire more research in this field, particularly from academics.
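For comparison with the GPU-hour figures in the CLIPA-v2 table, wall-clock days on one eight-GPU server convert to GPU hours as days × 24 × 8. A minimal sketch, assuming all eight GPUs are busy for the whole run; because the reported durations are approximate ("about N days"), the resulting GPU-hour figures are approximate too.

```python
GPUS_PER_SERVER = 8  # "an A100 eight-GPU server", as stated in the text
HOURS_PER_DAY = 24


def gpu_hours(wall_clock_days: float, num_gpus: int = GPUS_PER_SERVER) -> float:
    """Convert wall-clock days on one server into total GPU hours (all GPUs assumed busy)."""
    return wall_clock_days * HOURS_PER_DAY * num_gpus


# (zero-shot top-1 accuracy in %, approximate training days) from the paragraph above
for acc, days in [(63.2, 2), (67.8, 3), (69.3, 4)]:
    print(f"{acc}% zero-shot top-1: ~{days} days -> ~{gpu_hours(days):.0f} A100 GPU hours")
```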

TPU Usage

Our experiments are conducted on both GPUs and TPUs, and both the JAX and PyTorch implementations support TPU training. For how to gain access to and set up TPU machines, check this brief doc. In a nutshell, you can access TPU machines on Google Cloud for free!

License

This project is under the Apache 2.0 License.

Acknowledgement

The JAX codebase is built on big_vision, and the PyTorch codebase is built on OpenCLIP. We have also borrowed some code from timm and MAE. Many thanks to these awesome works from the open-source community!

We are also very grateful for the support of a gift from Open Philanthropy, the TPU Research Cloud (TRC) program, and the Google Cloud Research Credits program.

Citation

@inproceedings{li2023clipa,
      title={An Inverse Scaling Law for CLIP Training}, 
      author={Xianhang Li and Zeyu Wang and Cihang Xie},
      booktitle={NeurIPS},
      year={2023},
}
@article{li2023clipav2,
      title={CLIPA-v2: Scaling CLIP Training with 81.1% Zero-shot ImageNet Accuracy within a $10,000 Budget; An Extra $4,000 Unlocks 81.8% Accuracy}, 
      author={Xianhang Li and Zeyu Wang and Cihang Xie},
      journal={arXiv preprint arXiv:2306.15658},
      year={2023},
}

Contact

If you have any questions, please feel free to open an issue or contact us directly: Xianhang Li: xli421@ucsc.edu; Zeyu Wang: zwang615@ucsc.edu