mlfoundations / open_clip

An open source implementation of CLIP.

Handling Negative Pairs in Fine-Tuning of CLIP Models #869

Closed: doramasma closed this issue 2 months ago

doramasma commented 2 months ago

Hello guys!

First, I'd like to check my understanding of the contrastive loss for CLIP models. If I'm not mistaken, a CLIP model is trained with a contrastive loss involving the typical triplets of positive and negative pairs derived from the images. Put simply, training aims to maximize the cosine similarity between correct image-caption embedding pairs while minimizing the similarity between all incorrect pairs. Am I right?
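The objective described above can be sketched as a symmetric cross-entropy over the batch's image-text similarity matrix. For each image, its own caption is the positive and every other caption in the same batch is a negative, and likewise for each caption. This is a minimal pure-Python sketch for illustration only: the function names are hypothetical (not open_clip's API), features are assumed to be plain lists of floats, and `logit_scale` is fixed here although CLIP learns it during training.

```python
import math
from typing import List, Sequence

Vector = List[float]


def l2_normalize(v: Sequence[float]) -> Vector:
    """Scale a vector to unit length so dot products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Softmax cross-entropy for one row, using log-sum-exp for numerical stability."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]


def clip_style_contrastive_loss(image_feats, text_feats, logit_scale: float = 100.0) -> float:
    """Hypothetical helper: pair i is the positive, all other in-batch pairs are negatives."""
    imgs = [l2_normalize(v) for v in image_feats]
    txts = [l2_normalize(v) for v in text_feats]
    n = len(imgs)
    # logits[i][j] = scaled cosine similarity between image i and text j
    logits = [
        [logit_scale * sum(a * b for a, b in zip(imgs[i], txts[j])) for j in range(n)]
        for i in range(n)
    ]
    # Image -> text: row i should score caption i highest among all n captions.
    loss_i2t = sum(cross_entropy(logits[i], i) for i in range(n)) / n
    # Text -> image: column j should score image j highest among all n images.
    loss_t2i = sum(cross_entropy([logits[i][j] for i in range(n)], j) for j in range(n)) / n
    return (loss_i2t + loss_t2i) / 2


eye = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
print(clip_style_contrastive_loss(eye, eye))  # matched pairs: loss close to 0
print(clip_style_contrastive_loss(eye, eye[1:] + eye[:1]))  # every caption shifted: large loss
```

Note that no explicit negative is ever passed in: the negatives are simply the off-diagonal entries of the similarity matrix, so their number grows with the batch size.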

So, I would like to feed explicit negative pairs into my fine-tuning. In my experiment, the embeddings of both negative and positive samples must be well-distributed. However, the guide I followed doesn't specify how to handle negative samples. It seems that during training, the model uses the previous batch as the negative pairs. Am I missing something? Is this correct?

```python
# Now, ready to take gradients for the last accum_freq batches.
# Re-do the forward pass for those batches, and use the cached features from the other batches as negatives.
# Call backwards each time, but only step optimizer at the end.
optimizer.zero_grad()
for j in range(args.accum_freq):
    images = accum_images[j]
    texts = accum_texts[j]
    with autocast():
        model_out = model(images, texts)

        inputs_no_accum = {}
        inputs_no_accum["logit_scale"] = logit_scale = model_out.pop("logit_scale")
        if "logit_bias" in model_out:
            inputs_no_accum["logit_bias"] = model_out.pop("logit_bias")

        inputs = {}
        for key, accumulated in accum_features.items():
            inputs[key] = torch.cat(accumulated[:j] + [model_out[key]] + accumulated[j + 1:])

        losses = loss(**inputs, **inputs_no_accum, output_dict=True)
        del inputs
        del inputs_no_accum
        total_loss = sum(losses.values())
        losses["loss"] = total_loss

    backward(total_loss, scaler)
```
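In this excerpt, the `torch.cat` line rebuilds the full accumulated batch for micro-batch `j`: fresh features (with gradients) for `j` itself and cached features for the other `accum_freq - 1` micro-batches, both earlier and later ones in the same accumulation window. As the code's own comment says, those cached features serve as extra negatives alongside the other samples in micro-batch `j`; as far as I can tell this branch is only used when `accum_freq` is greater than 1, and otherwise the negatives come from the current batch alone. The index bookkeeping can be sketched in plain Python, assuming string labels as stand-ins for feature rows; the helper names are hypothetical and only mirror the concatenation above.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def assemble_features(cached: Sequence[List[T]], fresh: List[T], j: int) -> List[T]:
    """Rebuild the accumulated batch for micro-batch j.

    Mirrors torch.cat(accumulated[:j] + [model_out[key]] + accumulated[j + 1:]):
    micro-batch j uses freshly computed rows, every other micro-batch its cached rows.
    """
    chunks = list(cached[:j]) + [fresh] + list(cached[j + 1:])
    return [row for chunk in chunks for row in chunk]


def positive_index(cached: Sequence[List[T]], j: int, k: int) -> int:
    """Row of sample k of micro-batch j in the concatenated batch; every other row is a negative."""
    return sum(len(chunk) for chunk in cached[:j]) + k


cached = [["a0", "a1"], ["b0", "b1"], ["c0", "c1"]]  # stand-ins for cached feature rows
full = assemble_features(cached, ["B0", "B1"], j=1)
print(full)  # ['a0', 'a1', 'B0', 'B1', 'c0', 'c1']
idx = positive_index(cached, j=1, k=0)
print([row for i, row in enumerate(full) if i != idx])  # ['a0', 'a1', 'B1', 'c0', 'c1']
```

So each sample is contrasted against the whole accumulated batch, which is how gradient accumulation here approximates a larger global batch rather than reusing "the previous batch" specifically.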

However, this approach raises a question in my mind: what happens if the previous batch contains items that are quite similar to the items in the current batch? Wouldn't this affect the efficacy of the contrastive loss?

I would greatly appreciate any insights or clarifications on these points.

Thank you in advance!!

rwightman commented 2 months ago

@doramasma be careful not to confuse triplet loss with contrastive loss. Triplet losses require an anchor in addition to the positive and negative. Here we just have positive and negative pairings based on the image-text pairs in the dataset. And yes, the batching does impact performance quite significantly, for various reasons. As long as the positive pairs are true positives most of the time, and the negatives in the batch don't contain too many false negatives (I feel this is a bigger confounder than the cross-batch case you describe), there should be enough signal to learn from.
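To get a rough sense of how many in-batch false negatives a dataset produces, one crude check is to count pairs of samples in a batch whose captions are identical: the loss still treats each such pair as a negative and pushes it apart. This is only a proxy under stated assumptions (exact match after lowercasing and stripping whitespace; it misses paraphrases and near-duplicate images), and the helper below is hypothetical, not part of open_clip.

```python
from collections import Counter
from typing import Sequence


def count_exact_false_negatives(captions: Sequence[str]) -> int:
    """Count ordered (i, j) pairs, i != j, whose captions match after light normalization.

    In the in-batch contrastive loss, caption j is a negative for image i whenever i != j,
    so each such pair is a false negative in the image -> text direction.
    """
    counts = Counter(c.strip().lower() for c in captions)
    # A caption that appears m times yields m * (m - 1) ordered pairs with i != j.
    return sum(m * (m - 1) for m in counts.values())


batch = ["a dog on grass", "A dog on grass", "a red car", "a cat"]
print(count_exact_false_negatives(batch))  # 2
```

Because the number of negatives per sample grows with the global batch size, the chance of such collisions grows with it too, which is one reason dataset deduplication and caption quality matter here.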

As for what's efficient, there have been loads of papers proposing alternative formulations that try to address issues or inefficiencies in the 'basic' formulation of the loss: making it more sample efficient, doing better with smaller global batch sizes, etc. But the simple formulation scales, and many of those ideas weren't worth the overhead or just didn't perform at scale.