`SupConLoss` returns NaN instead of a finite loss when an anchor in the batch has no positive pair, i.e. no other sample (or view) with the same label.
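Assuming `SupConLoss` follows the standard supervised contrastive objective from the SupCon paper, each anchor $i$ averages over its positive set $P(i) = \{p \in A(i) : y_p = y_i\}$, where $A(i)$ is every other sample (and view) in the batch, $z$ are the normalized features and $\tau$ is the temperature. The loss then averages these per-anchor terms:

$$
\mathcal{L}_i = -\frac{1}{|P(i)|} \sum_{p \in P(i)} \log \frac{\exp(z_i \cdot z_p / \tau)}{\sum_{a \in A(i)} \exp(z_i \cdot z_a / \tau)}
$$

If $P(i) = \emptyset$, the sum is empty (0) and $|P(i)| = 0$, so $\mathcal{L}_i = 0/0$, which floating-point tensor arithmetic evaluates to NaN. A single NaN term is enough to make the batch average NaN.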
For example:
from losses import SupConLoss
import torch
import torch.nn.functional as F
# define loss with a temperature `temp`
temp = 0.8
criterion = SupConLoss(temperature=temp)
# features: [bsz, n_views, f_dim]
# `n_views` is the number of crops from each image
# should be L2-normalized along the f_dim dimension
features = torch.randn(4, 1, 3)
features = F.normalize(features, p=2, dim=-1)  # dim=-1 is f_dim; dim=1 would be n_views
# labels: [bsz]; labels 0 and 3 occur only once, so with n_views=1
# the anchors at indices 0 and 3 have no positive pair
labels = torch.tensor([0, 1, 1, 3])
# SupContrast
loss = criterion(features, labels)
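To see where the NaN comes from, here is a minimal plain-Python sketch of the per-anchor computation, assuming the single-view case (`n_views = 1`) from the example above and the standard SupCon definition. The names `supcon_anchor_terms`, `supcon_mean` and the `skip_no_positive` flag are hypothetical and are not part of `SupConLoss`; the sketch also omits the `temperature / base_temperature` rescaling. Plain Python raises `ZeroDivisionError` on `0/0` where PyTorch returns NaN, so anchors with an empty positive set are marked with `None` instead.

```python
import math


def supcon_anchor_terms(features, labels, temperature=0.8):
    """Per-anchor SupCon terms for a single view (n_views == 1).

    features: list of L2-normalized vectors (lists of floats)
    labels:   list of labels, same length as features
    Returns one entry per anchor: the loss term, or None when P(i) is empty.
    """
    n = len(features)
    if n < 2:
        raise ValueError("need at least two samples to form A(i)")
    terms = []
    for i in range(n):
        # A(i): every other sample in the batch
        others = [a for a in range(n) if a != i]
        logits = {
            a: sum(x * y for x, y in zip(features[i], features[a])) / temperature
            for a in others
        }
        # log of the softmax denominator, computed with a stable log-sum-exp
        m = max(logits.values())
        log_denom = m + math.log(sum(math.exp(v - m) for v in logits.values()))
        # P(i): other samples sharing anchor i's label
        positives = [p for p in others if labels[p] == labels[i]]
        if not positives:
            # |P(i)| == 0: this is the 0/0 case that PyTorch turns into NaN
            terms.append(None)
            continue
        mean_log_prob_pos = sum(logits[p] - log_denom for p in positives) / len(positives)
        terms.append(-mean_log_prob_pos)
    return terms


def supcon_mean(terms, skip_no_positive=True):
    """Average the per-anchor terms.

    skip_no_positive=False mimics the unguarded reduction: any anchor
    without a positive turns the whole mean into NaN.
    """
    if not skip_no_positive and any(t is None for t in terms):
        return float("nan")
    valid = [t for t in terms if t is not None]
    # If no anchor has a positive there is nothing to average; return 0.0
    return sum(valid) / len(valid) if valid else 0.0


feats = [[1.0, 0.0], [0.6, 0.8], [0.8, 0.6], [0.0, 1.0]]  # already unit-norm
labels = [0, 1, 1, 3]
terms = supcon_anchor_terms(feats, labels)
print(terms[0], terms[3])                          # None None
print(supcon_mean(terms, skip_no_positive=False))  # nan
print(supcon_mean(terms))                          # finite: mean over anchors 1 and 2
```

Dropping anchors without positives from the average is one possible guard; another is to clamp $|P(i)|$ to 1 so such anchors contribute a zero term. Using `n_views >= 2` also avoids the empty set, because each anchor's other views share its label and therefore count as positives.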
Outputs: