Oh, that's a good question. Label smoothing is quite a common strategy in contrastive training to avoid false negatives and overfitting. That's why we didn't mention it in the paper and didn't run an ablation.
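For concreteness, here is a minimal sketch (illustrative only, not the Long-CLIP implementation) of what label smoothing does to the contrastive targets for a batch of N image-text pairs, assuming a smoothing value of 0.1:
# Sketch only: the matched pair keeps 1 - eps of the probability mass;
# the remaining eps is spread uniformly over the N - 1 negatives.
import torch

N, eps = 4, 0.1
targets = torch.full((N, N), eps / (N - 1))
targets.fill_diagonal_(1.0 - eps)
print(targets[0])  # tensor([0.9000, 0.0333, 0.0333, 0.0333])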
Ah, interesting. Thanks for the quick response.
I've seen it used in the classification domain quite often, but my canonical reference for contrastive learning has been open_clip, and none of their loss functions use label smoothing, hence my original question.
Out of curiosity, could you share a paper reference that uses label smoothing? I'm very curious to see ablations regarding that in any context.
alpha-clip (https://arxiv.org/pdf/2312.03818), for example. Although they don't mention label smoothing in the paper, they use it in the source code (https://github.com/SunzeY/AlphaCLIP).
@rsomani95 and @beichenzbc - thank you for this most interesting discussion, I am glad I came across it!
I didn't pull the changes @beichenzbc made, but I implemented this suggestion in my custom training code, which includes Geometric Parametrization (GmP) of the linear layers as well as manipulation of activation values, predominantly of an 'adverb neuron' in the Vision Transformer, during fine-tuning. Importantly, the relative change between GmP-classic (classic contrastive loss) and GmP-smooth is due only to the new loss with smoothing; everything else (dataset, other hyperparameters, ...) was identical.
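As a rough illustration of the kind of reparametrization meant here, below is a sketch that assumes GmP splits each weight vector into a learned magnitude and a learned direction, similar in spirit to weight normalization; the actual implementation in the linked repos may differ.
# Sketch only (an assumption about GmP, not the exact code from the repos):
# each output unit learns a magnitude r and a direction theta, recombined as
# W = r * theta / ||theta||.
import torch
import torch.nn as nn
import torch.nn.functional as F

class GmPLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        w = torch.empty(out_features, in_features)
        nn.init.kaiming_uniform_(w, a=5 ** 0.5)
        norm = w.norm(dim=1, keepdim=True)
        self.r = nn.Parameter(norm)            # per-unit magnitude
        self.theta = nn.Parameter(w / norm)    # per-unit direction
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None

    def forward(self, x):
        weight = self.r * F.normalize(self.theta, dim=1)
        return F.linear(x, weight, self.bias)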
Indeed, I was able to squeeze out some additional accuracy for the label-smoothing fine-tune. Please note that, for comparison purposes, the dataset captions were capped to well under 77 tokens so they also fit the original ViT-L/14. It's possible that the full long captions of COCO-SPRIGHT-40k could boost Long-CLIP even more.
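For reference, one way to cap captions so they also fit the original 77-token context is the stock CLIP tokenizer's truncate flag (a sketch; the actual preprocessing used for these fine-tunes may differ):
# Sketch only: cap captions at CLIP's 77-token context window so the same
# data can be fed to the original ViT-L/14 text encoder.
import clip

captions = ["a long descriptive caption that may exceed the context window ..."]
tokens = clip.tokenize(captions, context_length=77, truncate=True)  # shape (N, 77)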
As usual, you can find the code to reproduce these exact fine-tunes (GmP-classic and GmP-smooth) on my fork of the Long-CLIP repo (and, for the original OpenAI/CLIP, on the CLIP-fine-tune repo). The models are available on Hugging Face for download, so feel free to probe them with ablations or anything else, @rsomani95!
Quick overview of modified loss below.
# Classic contrastive loss.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, logits_per_image, logits_per_text):
        # L2-normalize the features to avoid overflow or underflow
        logits_per_image = F.normalize(logits_per_image, p=2, dim=1)
        logits_per_text = F.normalize(logits_per_text, p=2, dim=1)
        # Cosine-similarity logits, scaled by the temperature
        logits = torch.matmul(logits_per_image, logits_per_text.t()) / self.temperature
        # Matching image-text pairs sit on the diagonal
        labels = torch.arange(logits.size(0), device=logits.device)
        # Symmetric loss: mean of image-to-text and text-to-image cross-entropy
        loss_img = self.criterion(logits, labels)
        loss_txt = self.criterion(logits.t(), labels)
        return (loss_img + loss_txt) / 2
# New custom loss with label smoothing.
class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.07, smoothing=0.1):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature
        self.smoothing = smoothing

    def forward(self, logits_per_image, logits_per_text):
        # L2-normalize the features to avoid overflow or underflow
        logits_per_image = F.normalize(logits_per_image, p=2, dim=1)
        logits_per_text = F.normalize(logits_per_text, p=2, dim=1)
        # Cosine-similarity logits, scaled by the temperature
        logits = torch.matmul(logits_per_image, logits_per_text.t()) / self.temperature
        labels = torch.arange(logits.size(0), device=logits.device)
        # Build smoothed targets: 1 - smoothing on the diagonal (matched pairs),
        # smoothing / (N - 1) spread over the negatives
        N = logits.size(0)
        smoothed_labels = torch.full_like(logits, self.smoothing / (N - 1))
        smoothed_labels.scatter_(1, labels.unsqueeze(1), 1.0 - self.smoothing)
        # Cross-entropy computed manually from log-softmax and the smoothed targets
        # (the target matrix is symmetric, so it is reused for both directions)
        log_probs = F.log_softmax(logits, dim=1)
        loss_img = -(smoothed_labels * log_probs).sum(dim=1).mean()
        log_probs = F.log_softmax(logits.t(), dim=1)
        loss_txt = -(smoothed_labels * log_probs).sum(dim=1).mean()
        return (loss_img + loss_txt) / 2
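For completeness, a quick usage sketch (the batch size, feature dimension, and random tensors below are illustrative only). Note that PyTorch >= 1.10 also exposes nn.CrossEntropyLoss(label_smoothing=...), which gives a very similar result but spreads the smoothing mass over all N classes, including the correct one, rather than over the N - 1 negatives as above.
# Usage sketch with random features standing in for encoder outputs.
import torch

criterion = ContrastiveLoss(temperature=0.07, smoothing=0.1)
image_features = torch.randn(8, 768)
text_features = torch.randn(8, 768)
loss = criterion(image_features, text_features)
print(loss.item())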
Thanks, that's really an interesting finding!
@beichenzbc thanks for the reference
@zer0int thanks for sharing your results. Quite promising to see the bump (tiny, but definitely notable IMO) in both regular CLIP and LongCLIP training!
Hello, thank you for sharing your work. I really enjoyed reading your paper.
I noticed that you are doing label smoothing in the loss function: https://github.com/beichenzbc/Long-CLIP/blob/9308e5e5ed2156e2465d490485101069f86dbf26/model/model_longclip.py#L475-L483
but I didn't see it called out in the paper. Could you shed some light on this decision? Did you try an ablation with and without smoothing?
Thanks!