HobbitLong / SupContrast

PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)
BSD 2-Clause "Simplified" License

NaN loss for no positive pair from SupCon #141

Closed by srsawant34 10 months ago

srsawant34 commented 11 months ago

The SupCon loss returns NaN for any anchor that has no positive pair in the batch, i.e. no other sample with the same label.

For example:

from losses import SupConLoss
import torch
import torch.nn.functional as F

# define loss with a temperature `temp`
temp = 0.8
criterion = SupConLoss(temperature=temp)

# features: [bsz, n_views, f_dim]
# `n_views` is the number of crops from each image
# should be L2-normalized along the f_dim dimension
features = torch.randn(4,1,3)
# note: dim=1 is the n_views axis here; dim=-1 would normalize along f_dim
features = F.normalize(features, p=2, dim=1)
# labels: [bsz]
labels = torch.tensor([0,1,1,3], dtype=float)

# SupContrast
loss = criterion(features, labels)

Outputs:

tensor([    nan, 36.9528, 36.9528,     nan])
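
The NaN pattern above is what you would expect if the loss averages log-probabilities over an anchor's positives by dividing by the positive count: labels 0 and 3 each appear once, so their count is zero and the division becomes 0/0. The following is a minimal single-view sketch of that mechanism and of one possible guard, not the repository's implementation. The function name `supcon_per_anchor` and the `safe` flag are hypothetical, the `temperature / base_temperature` scaling is omitted, and treating positive-free anchors as zero loss is an assumption rather than a confirmed fix.

```python
import torch
import torch.nn.functional as F


def supcon_per_anchor(features, labels, temperature=0.8, safe=True):
    """Hypothetical single-view SupCon sketch.

    features: [bsz, f_dim], L2-normalized; labels: [bsz].
    Returns one loss value per anchor.
    """
    bsz = features.shape[0]
    logits = features @ features.T / temperature
    # Subtract the row max for numerical stability (softmax is unchanged).
    logits = logits - logits.max(dim=1, keepdim=True).values.detach()

    not_self = ~torch.eye(bsz, dtype=torch.bool)
    # Positives: same label as the anchor, excluding the anchor itself.
    pos_mask = (labels[:, None] == labels[None, :]) & not_self

    # Log-softmax over every sample except the anchor.
    exp_logits = torch.exp(logits) * not_self
    log_prob = logits - torch.log(exp_logits.sum(dim=1, keepdim=True))

    n_pos = pos_mask.sum(dim=1)
    # Assumption: an anchor with no positives contributes zero loss.
    # Without the clamp, such an anchor evaluates 0 / 0 and yields NaN.
    denom = n_pos.clamp(min=1) if safe else n_pos
    return -(pos_mask * log_prob).sum(dim=1) / denom


torch.manual_seed(0)
features = F.normalize(torch.randn(4, 3), p=2, dim=-1)
labels = torch.tensor([0, 1, 1, 3])

print(supcon_per_anchor(features, labels, safe=False))  # NaN at indices 0 and 3
print(supcon_per_anchor(features, labels, safe=True))   # finite for every anchor
```

An alternative to clamping is to drop positive-free anchors before averaging, which changes the effective batch size instead of adding zero terms to the mean. Which of the two is preferable depends on how the batch loss is reduced.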