openai / CLIP

CLIP (Contrastive Language-Image Pretraining): predict the most relevant text snippet given an image.
MIT License

Influence of batch size on training convergence #25

Open CDitzel opened 3 years ago

CDitzel commented 3 years ago

According to your paper, you use a large batch size of ~32k samples, which means that the raw untrained network initially has a chance of ~1/32k of predicting the correct pair.

I am wondering how the convergence/learning process would differ if a binary classification problem were formulated instead: the network would be presented alternately with matching and non-matching text/image pairs and tasked with predicting whether each pair is actually in agreement or not.

In other words, what does the softmax over ~32k entries prior to the cross-entropy calculation bring to the table that cannot be achieved more conveniently with a sigmoid and binary cross entropy to predict matching/non-matching pairs? As a side effect, this would also remove the dependence on the batch size, which seems to be rather crucial.
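
For concreteness, here is a rough PyTorch-style sketch of the two formulations I have in mind (toy tensors stand in for the encoder outputs; the names, sizes, and temperature are just illustrative):

```python
import torch
import torch.nn.functional as F

n, d = 8, 512                                 # toy batch size and embedding dim
img = F.normalize(torch.randn(n, d), dim=1)   # stand-in for image encoder outputs
txt = F.normalize(torch.randn(n, d), dim=1)   # stand-in for text encoder outputs
logits = img @ txt.t() / 0.07                 # (n, n) similarities over all pairings

# (a) CLIP-style: softmax cross entropy, each image must pick its text out of n (and vice versa)
labels = torch.arange(n)
loss_ce = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2

# (b) the alternative asked about: sigmoid + BCE, each pair judged match / no-match independently
targets = torch.eye(n)
loss_bce = F.binary_cross_entropy_with_logits(logits, targets)
```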

thoppe commented 3 years ago

Not one of the authors (and curious about their response), but CLIP is trained with a contrastive loss (vs., say, the binary prediction problem you suggest). As far as I've seen, contrastive learning is really good at learning feature representations, whereas a small-batch binary problem does well at learning how to classify. As they are different objectives, they optimize for different things.

Newmu commented 3 years ago

Great question! We looked into this and related objectives a bit during preliminary research before settling on the training objective described in the paper.

We have to run the encoders at least once to calculate feature embeddings for the matching text/image pairs; this step is the vast majority of the compute cost of training. The pairwise combinations of all image-text feature interactions are then cheap to compute on top of this, because they re-use the already computed embeddings and only involve a single additional inner product of relatively low-dimensional vectors per pair.

However, we found that convergence was an issue if we trained with a binary cross-entropy loss over all pairwise combinations instead of a categorical cross-entropy loss. We had to either modify BCE (manually downweighting negative pairs so they don't dominate the loss) or "warm up" from a small number of negatives. For the warm-up version we started with one negative, which is similar to what you propose, and kept adding more so that accuracy over the set of choices stayed at ~50%. These variants performed similarly to CCE in small-scale initial experiments, but we felt they were kind of hacky and settled on the cross-entropy loss over all pairwise combinations, since softmax cross entropy is so common in image classification and our goal was primarily zero-shot image classification.
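
For illustration, one way such a downweighting could look in PyTorch (the weighting scheme below is just a sketch, not the exact one we used):

```python
import torch
import torch.nn.functional as F

n, d = 8, 512                                 # toy batch size and embedding dim
img = F.normalize(torch.randn(n, d), dim=1)
txt = F.normalize(torch.randn(n, d), dim=1)
logits = img @ txt.t() / 0.07                 # all pairwise similarities
targets = torch.eye(n)                        # 1 for matching pairs, 0 otherwise

# Downweight the n*(n-1) negatives so their total weight matches the n positives.
weights = torch.where(targets > 0,
                      torch.ones_like(targets),
                      torch.full_like(targets, 1.0 / (n - 1)))
loss_bce = F.binary_cross_entropy_with_logits(logits, targets, weight=weights)
```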

Other common alternatives in the literature, such as triplet loss or variants of it like VSE++, seemed to slightly underperform, likely because they use only a sparse subset of the loss terms, though we didn't extensively tune them or investigate this in more detail.

If you want to get "crazy" you can modify CLIP to try to directly optimize for the intuitive batch pairing objective using something like Sinkhorn instead of the symmetric cross-entropy loss we use. I got this idea from a former co-worker, Tim Salimans, who used it for Optimal Transport for GANs. We found it achieves slightly better representation-learning results (~1% ImageNet top-1 boost on linear probes) for an equivalent amount of data seen, but the improvement wasn't justifiable in terms of the wall-clock slowdown it incurred. It also breaks away from the clean interpretation of CLIP pre-training as optimizing a proxy to standard image classification, so we were unsure how it would do for zero-shot transfer. Someone could probably write a paper studying this if they want, though!
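
A minimal sketch of the Sinkhorn normalization step, for anyone curious (illustrative only; the iteration count and epsilon are arbitrary here, and this is not the exact code we ran):

```python
import torch

def sinkhorn(logits, n_iters=3, eps=0.05):
    """Turn a batch similarity matrix into a roughly doubly-stochastic
    assignment by alternately normalizing rows and columns."""
    Q = torch.exp(logits / eps)
    for _ in range(n_iters):
        Q = Q / Q.sum(dim=1, keepdim=True)   # each image spreads its mass over texts
        Q = Q / Q.sum(dim=0, keepdim=True)   # each text spreads its mass over images
    return Q
```

The resulting assignment would then play the role of the (soft) pairing target instead of the one-hot diagonal.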

After we settled on our training objective, however, we became aware of some work questioning the choice of CCE for image classification; for instance, see the discussion in "Are we done with ImageNet?". This area (changing / studying the performance of different training objectives) is a great avenue for future work. Ideally we would have written up our work here in the paper, but it was very preliminary / non-rigorous.

KeremTurgutlu commented 3 years ago

Great discussion! I am a bit confused regarding this part from the paper:

The calculation of embedding similarities was also sharded with individual GPUs computing only the subset of the pairwise similarities necessary for their local batch of embeddings

Does this mean that pairwise logit calculations were restricted to the local batch of a given GPU? If so, with a per-GPU batch size of 128 under distributed training, the loss would be calculated from the pairwise comparison of only 128 samples. For example, AFAIK SimCLR uses a large batch to improve contrastive learning:

To keep it simple, we do not train the model with a memory bank (Wu et al., 2018; He et al., 2019). Instead, we vary the training batch size N from 256 to 8192. A batch size of 8192 gives us 16382 negative examples per positive pair from both augmentation views.

MoCo, on the other hand, uses queues to increase the number of negative samples per loss calculation.

Is there any such mechanism in CLIP to increase the number of effective samples during loss calculation or is it regular distributed training with losses from 128 samples per GPU?

Edit: My question above seems like a duplicate of #29.

Thanks for this amazing work once again!

If you want to get "crazy" you can modify CLIP to try to directly optimize for the intuitive batch pairing objective using something like Sinkhorn instead of the symmetric cross-entropy loss we use.

A similar approach is used in SwAV, and in my own experiments it learns better visual representations than SimCLR or BYOL. Though that improvement might be largely due to the multi-crop views, and I am not sure how it would fit into a multimodal representation learning setting like CLIP.

LOOKCC commented 3 years ago

@KeremTurgutlu I ran into the same question when I wanted to train CLIP with 8 GPUs. I found that if my total batch size is 128, PyTorch splits it into 8x32, so I get eight 32x32 pairings; but if I use one GPU with a batch size of 128, I get a single 128x128 pairing. As we know, eight 32x32 matrices are not the same as one 128x128 matrix: many negative pairs are lost. I don't know which setup is suitable for CLIP training. What's more, my loss didn't decrease during training (maybe there is something wrong in my code); it's quite hard to get training to work.

KeremTurgutlu commented 3 years ago

@KeremTurgutlu I ran into the same question when I wanted to train CLIP with 8 GPUs. I found that if my total batch size is 128, PyTorch splits it into 8x32, so I get eight 32x32 pairings; but if I use one GPU with a batch size of 128, I get a single 128x128 pairing. As we know, eight 32x32 matrices are not the same as one 128x128 matrix: many negative pairs are lost. I don't know which setup is suitable for CLIP training. What's more, my loss didn't decrease during training (maybe there is something wrong in my code); it's quite hard to get training to work.

If you have 8 GPUs, you can surely fit 128 samples on each of them independently, given that 128 already fits on a single GPU in your experiments. Maybe your current code is doing that split.

My question was rather about increasing the number of effective negatives during loss calculation beyond what can be fed to a single GPU. One option is to gather all text and image embeddings from every GPU onto each GPU; as far as I understand from other issues, this is what the original CLIP training does. This is also probably how SimCLR is trained, using the same InfoNCE loss.
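
A rough sketch of that gather step (illustrative PyTorch, assuming the default process group is already initialized; note that all_gather does not propagate gradients across ranks, so the local tensor is put back in place):

```python
import torch
import torch.distributed as dist

def gather_embeddings(local_emb):
    """Collect embeddings from every rank so each GPU can build the full
    similarity matrix, while keeping gradients for its own samples."""
    world_size = dist.get_world_size()
    gathered = [torch.zeros_like(local_emb) for _ in range(world_size)]
    dist.all_gather(gathered, local_emb)
    gathered[dist.get_rank()] = local_emb  # re-insert the local tensor so gradients flow
    return torch.cat(gathered, dim=0)
```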

Unfortunately, not everyone has access to multi-node GPU machines; even getting an 8-GPU machine is very expensive, but let's assume you have access to such a machine for this kind of training. You would still be capped at a maximum of 128*8 = 1024 samples. MoCo addresses this exact issue by using queues: you can set the number of samples in the queue to a much higher number than 1024. My question was more about whether using a pair of query and key encoders for both images and texts would make sense: 1 image query encoder, 1 image key encoder, 1 text query encoder, and 1 text key encoder. Similarities could then be calculated between the image query embeddings and the text key embeddings, and likewise between the text query embeddings and the image key embeddings.
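
The queue part of that idea could look roughly like this (just an illustrative MoCo-style FIFO buffer of key embeddings, untested):

```python
import torch
import torch.nn.functional as F

class EmbeddingQueue:
    """FIFO buffer of past key embeddings, used as extra negatives
    beyond the current batch (MoCo-style)."""
    def __init__(self, dim, size=8192):
        self.buffer = F.normalize(torch.randn(size, dim), dim=1)
        self.ptr = 0

    @torch.no_grad()
    def enqueue(self, keys):
        n = keys.shape[0]
        idx = (self.ptr + torch.arange(n)) % self.buffer.shape[0]
        self.buffer[idx] = keys                       # overwrite oldest entries
        self.ptr = (self.ptr + n) % self.buffer.shape[0]
```

Similarities for the loss would then be computed against both the current batch keys and the queued keys, with the queued columns acting as extra negatives.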

LOOKCC commented 3 years ago

@KeremTurgutlu No doubt using queues is a good approach when one's GPUs are limited; I know how contrastive methods such as MoCo do it. If you train CLIP the way its paper does, it's an n_class (= batch_size) classification task; if you train it the contrastive way, it's a binary classification task (related or unrelated). This question is really interesting. I spent some time thinking about it, and I think binary classification may not work as well as n_class classification. In contrastive learning the samples are all images (the same domain), so it's easy to judge whether a pair matches or not; but in CLIP one side is an image and the other is text, so it's not easy to judge a match directly, and some negative samples are needed to make a relative distinction. In this way CLIP keeps the problem simple. This is just my personal idea; everyone is welcome to discuss.

raijinspecial commented 3 years ago

@KeremTurgutlu I ran into the same question when I wanted to train CLIP with 8 GPUs. I found that if my total batch size is 128, PyTorch splits it into 8x32, so I get eight 32x32 pairings; but if I use one GPU with a batch size of 128, I get a single 128x128 pairing. As we know, eight 32x32 matrices are not the same as one 128x128 matrix: many negative pairs are lost. I don't know which setup is suitable for CLIP training. What's more, my loss didn't decrease during training (maybe there is something wrong in my code); it's quite hard to get training to work.

I also encountered this issue of the loss getting stuck while trying to train CLIP on a single 8-GPU machine. By sacrificing some image resolution (resizing down to 128) I was able to train with a batch size of 1024, and after ~500 steps the loss started to improve. I tried many different configurations of layers and widths with batch sizes between 32 and 256, and none of them could learn.

I am curious if you or Kerem have successfully tried any of the tricks mentioned above to increase the effective size of the contrastive pairs.

KeremTurgutlu commented 3 years ago

I am curious if you or Kerem have successfully tried any of the tricks mentioned above to increase the effective size of the contrastive pairs.

So far, I have tried both batch size 128 at 224 resolution and batch size 256 at 224 resolution (with gradient checkpointing) on a single GPU. Both were trained for 1 epoch on a dataset of 4.4 million image-text pairs, ~40k steps, and the models seem to converge. I am now going to try my distributed InfoNCE loss implementation with a 256*4 effective batch size to see if it improves.

dragen1860 commented 3 years ago

I am curious if you or Kerem have successfully tried any of the tricks mentioned above to increase the effective size of the contrastive pairs.

So far, I have tried both batch size 128 at 224 resolution and batch size 256 at 224 resolution (with gradient checkpointing) on a single GPU. Both were trained for 1 epoch on a dataset of 4.4 million image-text pairs, ~40k steps, and the models seem to converge. I am now going to try my distributed InfoNCE loss implementation with a 256*4 effective batch size to see if it improves.

Hi @KeremTurgutlu, I noticed you have implemented your training code and even started training on your own machine, which is really cool!! I am stuck on how to implement the training code myself. Could you kindly share it with me? I can't wait to reproduce this amazing work with your training code! Thank you. Here is my email address: liangqu.long@gmail.com

realTaki commented 2 years ago

@KeremTurgutlu I ran into the same question when I wanted to train CLIP with 8 GPUs. I found that if my total batch size is 128, PyTorch splits it into 8x32, so I get eight 32x32 pairings; but if I use one GPU with a batch size of 128, I get a single 128x128 pairing. As we know, eight 32x32 matrices are not the same as one 128x128 matrix: many negative pairs are lost. I don't know which setup is suitable for CLIP training. What's more, my loss didn't decrease during training (maybe there is something wrong in my code); it's quite hard to get training to work.

I chose to just collect all embeddings from the 8 GPUs, then build the similarity matrix and compute the loss on the CPU, which has a huge amount of memory. In other words, it is possible to get the full 128x128 similarity matrix across multiple GPUs this way. But it is really hard to run a 32k batch size, and in my tests it seems that a small batch size does reduce performance.

patrickvonplaten commented 11 months ago

BCE has now been popularized in https://arxiv.org/pdf/2303.15343.pdf, using pretty much exactly what the CLIP team tried before, it seems:

At initialization, the heavy imbalance coming from the many negatives dominates the loss, leading to large initial optimization steps attempting to correct this bias. To alleviate this, we introduce an additional learnable bias term b similar to the temperature t. We initialize t′ and b to log 10 and −10 respectively. This makes sure the training starts roughly close to the prior and does not require massive over-correction. Algorithm 1 presents a pseudocode implementation of the proposed sigmoid loss for language image pre-training.
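
A short sketch of that sigmoid loss in PyTorch (adapted loosely from the paper's pseudocode; the parameter handling here is illustrative):

```python
import torch
import torch.nn.functional as F

def sigmoid_loss(img_emb, txt_emb, t_prime, b):
    """img_emb, txt_emb: (n, d) L2-normalized embeddings.
    t_prime, b: learnable scalar tensors, initialized to log(10) and -10."""
    n = img_emb.shape[0]
    logits = img_emb @ txt_emb.t() * t_prime.exp() + b
    labels = 2 * torch.eye(n, device=logits.device) - 1   # +1 on the diagonal, -1 elsewhere
    return -F.logsigmoid(labels * logits).sum() / n
```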

miguelalba96 commented 9 months ago

I'm trying to fine-tune the HF implementation of the 336px model (so 440M params) with LoRA (4M additional parameters are trained, the rest is frozen). I'm training with Lightning Fabric since it can be adapted more easily to the OpenCLIP repo, but I cannot make the batch size greater than 8 per device, so 8 V100s x 8 = 64 effective batch size. Is this normal/expected, or should I be able to fit a larger batch size on a single GPU? Any help would be greatly appreciated :)