Closed. srsawant34 closed this pull request 6 months ago.
@HobbitLong Kindly look into this pull request. Thanks :)
Hi @srsawant34, thanks for this commit.
Why does the loss for the second item jump from 36.9528 to 65.1030?
@HobbitLong Hi, thanks for your response. In both of the scenarios mentioned in the code block, the features tensor is generated randomly, i.e. the two runs do not use the same features, so their per-item loss values are not directly comparable. I hope this helps.
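As an aside, one way to make randomly generated features comparable across runs is to fix the RNG seed before sampling them; a minimal standalone sketch, with placeholder shapes:

import torch

torch.manual_seed(0)             # fix the RNG so repeated runs sample identical features
features = torch.randn(4, 1, 3)  # hypothetical [bsz, n_views, f_dim]

With the same seed, the old and the new loss see identical inputs, so their per-item values can be compared directly.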
@HobbitLong Please allow me to demonstrate with the same feature tensor.
With the previous code:
from losses import SupConLoss
import torch
import torch.nn.functional as F
# define loss with a temperature `temp`
temp = 0.8
criterion = SupConLoss(temperature=temp)
# features: [bsz, n_views, f_dim]
# `n_views` is the number of crops from each image
# better be L2 normalized in f_dim dimension
features = torch.tensor(
[[[-2.1968, -0.8786, -0.9070]],
[[ 0.7557, -0.8486, -0.2785]],
[[-0.5093, 0.0999, 0.5296]],
[[-0.3224, 1.2153, 0.9145]]],
dtype=torch.float32
) # [ 4, 1, 3 ]
features = F.normalize(features, p=2, dim=1)
# labels: [bsz]
labels = torch.tensor([0,1,1,3], dtype=float)
# SupContrast
loss = criterion(features, labels)
Outputs (before mean):
Loss: tensor([ nan, 57.2958, 85.7973, nan])
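For context, the two nan entries correspond to labels 0 and 3, which each occur only once in the batch: those anchors have no positives, so the mean of log-probabilities over positives becomes a 0/0 division. A minimal standalone sketch of that failure mode (not the repository code; log_prob is a placeholder):

import torch

labels = torch.tensor([0, 1, 1, 3])
# positives mask: same label, excluding the anchor itself (as SupConLoss does)
mask = torch.eq(labels.view(-1, 1), labels.view(1, -1)).float()
mask.fill_diagonal_(0)
log_prob = torch.zeros(4, 4)  # placeholder log-probabilities

# anchors 0 and 3 have mask.sum(1) == 0, so the division produces nan for them
mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
print(mean_log_prob_pos)  # tensor([nan, 0., 0., nan])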
With the new changes (running the exact same script, now with the updated SupConLoss):
Outputs (before mean):
Loss: tensor([-0.0000, 57.2958, 85.7973, -0.0000])
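The visible effect of the fix is that anchors without any positive (labels 0 and 3 here) now contribute 0 instead of nan, while the other entries are unchanged. One common way to obtain this behaviour is to guard the denominator before dividing; a hedged sketch of such a guard (the exact change merged in this PR may differ):

import torch

labels = torch.tensor([0, 1, 1, 3])
mask = torch.eq(labels.view(-1, 1), labels.view(1, -1)).float()
mask.fill_diagonal_(0)
log_prob = torch.zeros(4, 4)  # placeholder log-probabilities

# clamp the positive count to at least 1 so anchors without positives yield 0, not nan
num_pos = mask.sum(1).clamp(min=1)
mean_log_prob_pos = (mask * log_prob).sum(1) / num_pos
print(mean_log_prob_pos)  # tensor([0., 0., 0., 0.])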
@HobbitLong Kindly look into this; your time is much appreciated. Thanks
@srsawant34, thank you for the contribution. Merged!
This resolves issue #141.