HobbitLong / SupContrast

PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)
BSD 2-Clause "Simplified" License

NaN loss for no positive pair from SupCon #141

Closed: srsawant34 closed this issue 5 months ago

srsawant34 commented 7 months ago

`SupConLoss` returns NaN for any anchor that has no positive pair, i.e. no other sample in the batch with the same label.
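The NaN comes from the averaging step. The loss averages each anchor's log-probabilities over its positives by dividing their sum by the number of positives. The sketch below isolates that step. It assumes, as in the repo's `losses.py`, that self-contrast is removed from the positive mask. The `log_prob` values are random stand-in data, not real logits. With labels `[0, 1, 1, 3]` and a single view per image, anchors 0 and 3 have no positives, so their average is 0/0.

```python
import torch

labels = torch.tensor([0, 1, 1, 3])

# Positive mask: 1 where two samples share a label, with the anchor itself removed
mask = torch.eq(labels.view(-1, 1), labels.view(1, -1)).float()
mask.fill_diagonal_(0)
num_pos = mask.sum(dim=1)
print(num_pos)  # anchors 0 and 3 have zero positives

# Stand-in log-probabilities; any finite values show the same effect
log_prob = torch.randn(4, 4)

# Mean over positives = sum / count; with no positives this is 0 / 0 = NaN
mean_log_prob_pos = (mask * log_prob).sum(dim=1) / num_pos
print(mean_log_prob_pos.isnan())  # True for anchors 0 and 3
```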

For example:

from losses import SupConLoss
import torch
import torch.nn.functional as F

# define loss with a temperature `temp`
temp = 0.8
criterion = SupConLoss(temperature=temp)

# features: [bsz, n_views, f_dim]
# `n_views` is the number of crops from each image
# should be L2-normalized along the f_dim (last) dimension
features = torch.randn(4,1,3)
features = F.normalize(features, p=2, dim=-1)
# labels: [bsz]; anchors 0 and 3 have no other sample with the same label
labels = torch.tensor([0,1,1,3])

# SupContrast
loss = criterion(features, labels)

Output (per-anchor loss before the final mean; the finite values depend on the random features):

tensor([    nan, 36.9528, 36.9528,     nan])
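One possible guard is sketched below as a hypothetical standalone function, `supcon_loss_per_anchor`. It is not the repo's API. It re-implements the "all" contrast mode and clamps the positive count to at least 1, so anchors without positives contribute 0/1 = 0 instead of NaN. It assumes the features are already L2-normalized along the last dimension. Like the output above, it returns the per-anchor loss before the final mean.

```python
import torch
import torch.nn.functional as F


def supcon_loss_per_anchor(features, labels, temperature=0.07, base_temperature=0.07):
    """Hypothetical helper: per-anchor SupCon loss ("all" contrast mode)
    with a guard for anchors that have no positive pair."""
    bsz, n_views, _ = features.shape
    # Stack all views into one batch: [bsz * n_views, f_dim]
    contrast = torch.cat(torch.unbind(features, dim=1), dim=0)
    labels = labels.view(-1, 1)
    mask = torch.eq(labels, labels.T).float().repeat(n_views, n_views)

    logits = contrast @ contrast.T / temperature
    # Subtract the row max for numerical stability
    logits = logits - logits.max(dim=1, keepdim=True).values.detach()

    # Remove self-contrast from both the positive mask and the denominator
    n = bsz * n_views
    logits_mask = 1.0 - torch.eye(n, device=features.device)
    mask = mask * logits_mask

    exp_logits = torch.exp(logits) * logits_mask
    log_prob = logits - torch.log(exp_logits.sum(dim=1, keepdim=True))

    # Guard: clamp the positive count to 1 so anchors without positives
    # get 0 / 1 = 0 instead of 0 / 0 = NaN
    num_pos = mask.sum(dim=1)
    num_pos = torch.where(num_pos < 1e-6, torch.ones_like(num_pos), num_pos)
    mean_log_prob_pos = (mask * log_prob).sum(dim=1) / num_pos

    return -(temperature / base_temperature) * mean_log_prob_pos


# Usage with the same shapes and labels as the reproduction above
torch.manual_seed(0)
features = F.normalize(torch.randn(4, 1, 3), p=2, dim=-1)
labels = torch.tensor([0, 1, 1, 3])
loss = supcon_loss_per_anchor(features, labels, temperature=0.8)
print(loss)  # finite; anchors 0 and 3 are zero
```

Clamping keeps the batch mean defined. However, the anchors without positives still count in the final mean. Masking them out instead, by averaging only over anchors with `num_pos > 0` before clamping, avoids diluting the loss. The problem only arises with `n_views == 1`. With two or more views, every anchor has its own other views as positives.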