moein-shariatnia / OpenAI-CLIP

Simple implementation of OpenAI CLIP model in PyTorch.
MIT License

Reason for custom Cross Entropy Loss Function #18

Closed. ItzHaad closed this issue 1 week ago.

ItzHaad commented 3 months ago

Hi, I was looking at your code and wanted to ask why you didn't use torch.nn.CrossEntropyLoss() and wrote your own implementation instead, since PyTorch also supports a target matrix of class probabilities, which would also be more stable. Or is there some significance to these lines that I might not be understanding correctly?

log_softmax = nn.LogSoftmax(dim=-1)
loss = (-targets * log_softmax(preds)).sum(1)
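A minimal sketch comparing the two, assuming PyTorch 1.10 or later (the release that added class-probability targets to torch.nn.CrossEntropyLoss and torch.nn.functional.cross_entropy). The cross_entropy helper below is reconstructed from the two lines quoted above, and the logits and soft targets are random stand-ins:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def cross_entropy(preds, targets, reduction="none"):
    # Soft-target cross entropy, reconstructed from the quoted lines:
    # loss_i = -sum_j targets[i, j] * log_softmax(preds)[i, j]
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    raise ValueError(f"unsupported reduction: {reduction!r}")


torch.manual_seed(0)
batch_size, num_classes = 4, 4
logits = torch.randn(batch_size, num_classes)
# Soft targets: each row is a probability distribution (rows sum to 1).
targets = F.softmax(torch.randn(batch_size, num_classes), dim=-1)

custom = cross_entropy(logits, targets, reduction="none")
# Built-in version with class-probability targets of the same shape as logits.
builtin = F.cross_entropy(logits, targets, reduction="none")
print(torch.allclose(custom, builtin, atol=1e-6))  # True
```

Both paths apply log_softmax before the weighted sum over classes, so on this example they agree to within float tolerance.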

Also, is there any particular reason why reduction is set to 'none' rather than 'mean' in the forward pass?

BTW, that's quite a nice trick for handling multi-labels.

moein-shariatnia commented 3 months ago

Hi,

Yes, you're right. In the README section after the loss definition, I explain this:

Here's why I didn't use a simpler approach: I have to admit that there's a simpler way to calculate this loss in PyTorch, by doing this: nn.CrossEntropyLoss()(logits, torch.arange(batch_size)). Why did I not use it here? ...
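The "simpler way" quoted above can be related to the soft-target form: when the target matrix is one-hot (the identity, i.e. text i matches only image i), the two give the same loss. A minimal sketch with random stand-in logits for a square batch of matched pairs:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch_size = 5
# Random stand-in for a (batch_size, batch_size) text-image logit matrix.
logits = torch.randn(batch_size, batch_size)

# Hard labels: the i-th row's correct class is column i.
hard_loss = nn.CrossEntropyLoss()(logits, torch.arange(batch_size))

# The same labels written as a one-hot (identity) target matrix.
identity_targets = torch.eye(batch_size)
soft_loss = (-identity_targets * F.log_softmax(logits, dim=-1)).sum(1).mean()

print(torch.allclose(hard_loss, soft_loss, atol=1e-6))  # True
```

The difference only shows up when the target matrix is not one-hot: a soft target row can put weight on more than one column, which a single integer label from torch.arange cannot express.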

ItzHaad commented 3 months ago

1) I already understand the logic behind not using torch.arange. My question is much simpler: I was just asking why your code needed a custom cross-entropy function. Is it doing something different from the standard CrossEntropyLoss that PyTorch offers?

So, in:

texts_loss = cross_entropy(logits, targets, reduction='none')
images_loss = cross_entropy(logits.T, targets.T, reduction='none')
loss = (images_loss + texts_loss) / 2.0 # shape: (batch_size)
return loss.mean()

Can I just do this instead?

texts_loss = torch.nn.CrossEntropyLoss(reduction='none')(logits, targets)
images_loss = torch.nn.CrossEntropyLoss(reduction='none')(logits.T, targets.T)
loss = (images_loss + texts_loss) / 2.0 # shape: (batch_size)
return loss.mean()
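Assuming PyTorch 1.10 or later, the proposed swap can be sanity-checked numerically. Note that torch.nn.CrossEntropyLoss is a module: it is constructed with options such as reduction and then called on (input, target). The logits and soft targets below are random stand-ins, not the repo's actual similarity matrices, and the helper mirrors the custom cross_entropy quoted in this thread:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def cross_entropy(preds, targets, reduction="none"):
    # Soft-target cross entropy in the style of the custom helper.
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(1)
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    raise ValueError(f"unsupported reduction: {reduction!r}")


torch.manual_seed(0)
batch_size = 6
logits = torch.randn(batch_size, batch_size)
targets = F.softmax(torch.randn(batch_size, batch_size), dim=-1)

# Custom helper: text->image rows and image->text rows (via transpose).
texts_loss = cross_entropy(logits, targets, reduction="none")
images_loss = cross_entropy(logits.T, targets.T, reduction="none")
custom = ((images_loss + texts_loss) / 2.0).mean()

# Built-in module: construct once with reduction="none", then call it.
ce = nn.CrossEntropyLoss(reduction="none")
texts_loss_b = ce(logits, targets)
images_loss_b = ce(logits.T, targets.T)
builtin = ((images_loss_b + texts_loss_b) / 2.0).mean()

print(torch.allclose(custom, builtin, atol=1e-6))  # True
```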

2) Another question: is there any particular reason why reduction is set to 'none' rather than 'mean' in cross_entropy?
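A small numerical check of the reduction question, using random stand-in per-sample losses rather than real model outputs. Because the mean is linear, averaging the two directions per sample and then taking the batch mean gives the same scalar as reducing each direction with 'mean' first; keeping reduction='none' simply leaves the per-sample losses (shape (batch_size,)) available before the final mean:

```python
import torch

torch.manual_seed(0)
batch_size = 8
# Stand-in per-sample losses, like those returned with reduction="none".
texts_loss = torch.rand(batch_size)
images_loss = torch.rand(batch_size)

# Pattern from the snippet: average the two directions per sample, then over the batch.
per_sample = (images_loss + texts_loss) / 2.0
loss_none_then_mean = per_sample.mean()

# Alternative: reduce each direction with "mean" first, then average the two scalars.
loss_mean_first = (images_loss.mean() + texts_loss.mean()) / 2.0

print(torch.allclose(loss_none_then_mean, loss_mean_first))  # True
print(per_sample.shape)  # torch.Size([8])
```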

litingfeng commented 1 week ago

I compared the two implementations, and it turned out that cross_entropy and torch.nn.CrossEntropyLoss are NOT equivalent. The following function, which takes integer class indices as targets and one-hot encodes them, is equivalent:

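import torch
import torch.nn as nn

def cross_entropy(preds, targets, reduction='none'):
    # preds: (batch_size, num_classes) logits; targets: (batch_size,) integer class indices
    log_softmax = nn.LogSoftmax(dim=-1)
    log_probs = log_softmax(preds)
    # Manually create one-hot encoded targets from the class indices
    one_hot_targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
    loss = (-one_hot_targets * log_probs).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()
    raise ValueError(f"unsupported reduction: {reduction!r}")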
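A hedged usage check, assuming the targets are integer class indices (torch.long). The helper name cross_entropy_from_indices is hypothetical; it performs the same computation as the function above but builds the one-hot matrix with torch.nn.functional.one_hot instead of scatter_:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def cross_entropy_from_indices(preds, targets, reduction="none"):
    # Hypothetical helper: one-hot encode integer labels, then apply the soft-target formula.
    log_probs = F.log_softmax(preds, dim=-1)
    one_hot_targets = F.one_hot(targets, num_classes=preds.size(-1)).to(log_probs.dtype)
    loss = (-one_hot_targets * log_probs).sum(1)
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    raise ValueError(f"unsupported reduction: {reduction!r}")


torch.manual_seed(0)
preds = torch.randn(4, 3)
labels = torch.tensor([0, 2, 1, 2])  # integer class indices

ours = cross_entropy_from_indices(preds, labels, reduction="mean")
builtin = nn.CrossEntropyLoss()(preds, labels)  # default reduction is "mean"
print(torch.allclose(ours, builtin, atol=1e-6))  # True
```

With soft (probability) targets no one-hot step is needed; the probability matrix is passed directly as the target.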