beichenzbc / Long-CLIP

[ECCV 2024] official code for "Long-CLIP: Unlocking the Long-Text Capability of CLIP"
Apache License 2.0

Question re. Label Smoothing #62

Closed: rsomani95 closed this issue 3 months ago

rsomani95 commented 3 months ago

Hello, thank you for sharing your work. I really enjoyed reading your paper.

I noticed that you are doing label smoothing in the loss function: https://github.com/beichenzbc/Long-CLIP/blob/9308e5e5ed2156e2465d490485101069f86dbf26/model/model_longclip.py#L475-L483

but never saw this being called out in the paper. I was curious if you could shed some light on this decision? Did you try an ablation with/without smoothing?

Thanks!

beichenzbc commented 3 months ago

Oh, that's a good question. Label smoothing is a quite common strategy in contrastive training to avoid false negatives and overfitting, so we didn't mention it in the paper and didn't run an ablation.
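
(For reference, a minimal sketch of what this looks like in a CLIP-style contrastive loss, using PyTorch's built-in label_smoothing argument of nn.CrossEntropyLoss; the function name and logit scale below are illustrative only, not the Long-CLIP training code.)

import torch
import torch.nn as nn
import torch.nn.functional as F

def clip_contrastive_loss(image_features, text_features, logit_scale=100.0, smoothing=0.1):
    # Cosine-similarity logits between every image and every text in the batch.
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    logits_per_image = logit_scale * image_features @ text_features.t()
    logits_per_text = logits_per_image.t()

    # The matching pair sits on the diagonal; label smoothing moves a little
    # probability mass onto the other pairs, softening potential false negatives.
    labels = torch.arange(image_features.size(0), device=image_features.device)
    criterion = nn.CrossEntropyLoss(label_smoothing=smoothing)
    return (criterion(logits_per_image, labels) + criterion(logits_per_text, labels)) / 2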

rsomani95 commented 3 months ago

Ah, interesting. Thanks for the quick response.

I've seen it used in the classification domain quite often, but my canonical reference for contrastive learning has been open_clip, and none of their loss functions use label smoothing; hence my original question.

Out of curiosity, could you share a paper reference that uses label smoothing? I'm very curious to see ablations regarding that in any context.

beichenzbc commented 3 months ago

Alpha-CLIP (https://arxiv.org/pdf/2312.03818), for example. Although they don't mention label smoothing in the paper, they use it in the source code (https://github.com/SunzeY/AlphaCLIP).

zer0int commented 3 months ago

@rsomani95 and @beichenzbc - thank you for this most interesting discussion, I am glad I came across it!

I didn't pull the changes @beichenzbc made, but I implemented this suggestion in my custom training code, which includes Geometric Parametrization (GmP) of the linear layers as well as activation-value manipulation of (predominantly) an 'adverb neuron' in the Vision Transformer during fine-tuning. The relative change between GmP-classic (classic contrastive loss) and GmP-smooth is due only to the new loss with smoothing; everything else (dataset, other hyperparameters, ...) was the same.

Indeed, I was able to squeeze some additional accuracy out of the label-smoothing fine-tune. Please note that, for comparison purposes, the dataset labels were capped to <77 tokens so that they also fit the original ViT-L/14. I assume it's possible that the full long captions of COCO-SPRIGHT-40k could boost Long-CLIP even more.

[Image: longclip-eval-gmp-smooth]

[Image: gmp-models-extreme-plot-all-evals]

As usual, you can find the code to reproduce these exact fine-tunes (GmP-classic and GmP-smooth) on my fork of the Long-CLIP repo (and in the CLIP-fine-tune repo for the original OpenAI CLIP, respectively). The models are on Hugging Face for download, so feel free to probe them with ablations or anything else, @rsomani95!

A quick overview of the modified loss is below.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Classic contrastive loss.

class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, logits_per_image, logits_per_text):
        # Normalize the features to avoid overflow or underflow
        logits_per_image = F.normalize(logits_per_image, p=2, dim=1)
        logits_per_text = F.normalize(logits_per_text, p=2, dim=1)

        # Calculate logits
        logits = torch.matmul(logits_per_image, logits_per_text.t()) / self.temperature
        labels = torch.arange(logits.size(0), device=logits.device)

        # Calculate loss as the mean of the two cross-entropy losses
        loss_img = self.criterion(logits, labels)
        loss_txt = self.criterion(logits.t(), labels)

        return (loss_img + loss_txt) / 2

# New Custom Loss with smoothing.

class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.07, smoothing=0.1):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature
        self.smoothing = smoothing

    def forward(self, logits_per_image, logits_per_text):
        # Normalize the features to avoid overflow or underflow
        logits_per_image = F.normalize(logits_per_image, p=2, dim=1)
        logits_per_text = F.normalize(logits_per_text, p=2, dim=1)

        # Calculate logits
        logits = torch.matmul(logits_per_image, logits_per_text.t()) / self.temperature
        labels = torch.arange(logits.size(0), device=logits.device)

        # Apply label smoothing
        N = logits.size(0)
        smoothed_labels = torch.full_like(logits, self.smoothing / (N - 1))
        smoothed_labels.scatter_(1, labels.unsqueeze(1), 1.0 - self.smoothing)

        # Calculate loss manually using log-softmax and smoothed labels
        log_probs = F.log_softmax(logits, dim=1)
        loss_img = -(smoothed_labels * log_probs).sum(dim=1).mean()

        log_probs = F.log_softmax(logits.t(), dim=1)
        loss_txt = -(smoothed_labels * log_probs).sum(dim=1).mean()

        return (loss_img + loss_txt) / 2
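
For completeness, a hypothetical usage sketch of the smoothed ContrastiveLoss above, with random tensors standing in for the image/text embeddings that the class L2-normalizes internally:

import torch

batch_size, embed_dim = 8, 768                     # illustrative sizes only
image_feats = torch.randn(batch_size, embed_dim)   # stand-in for image embeddings
text_feats = torch.randn(batch_size, embed_dim)    # stand-in for text embeddings

loss_fn = ContrastiveLoss(temperature=0.07, smoothing=0.1)
print(loss_fn(image_feats, text_feats).item())

One small detail: the manual targets put 1 - smoothing on the matching pair and smoothing / (N - 1) on each of the other N - 1 pairs, which is close to, but not identical with, PyTorch's built-in CrossEntropyLoss(label_smoothing=...), which mixes in a uniform distribution over all N entries.
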
beichenzbc commented 3 months ago

Thanks, that's really an interesting finding!

rsomani95 commented 3 months ago

@beichenzbc thanks for the reference

@zer0int thanks for sharing your results. Quite promising to see the bump (tiny, but definitely notable IMO) in both regular CLIP and Long-CLIP training!