dangvansam closed this issue 2 years ago.
Hi. This kind of input won't work, because the shapes don't match what the function expects.
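The shapes should be: `anchors`: (N, D), `positives`: (N, D), `negatives`: (N, M, D), where N is the number of anchors, M the number of negatives per anchor, and D the embedding dimension.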
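To make the expected shapes concrete, here is a minimal sketch of a paired InfoNCE loss written directly in PyTorch. The helper name `paired_info_nce`, the use of cosine similarity, and `temperature=0.1` are assumptions for illustration; this is not the `info_nce` package's actual implementation. The point is that each anchor gets one logit from its positive and M logits from its own negatives, so the logits are (N, 1 + M) and the positive is class 0.

```python
import torch
import torch.nn.functional as F


def paired_info_nce(anchors, positives, negatives, temperature=0.1):
    """Hypothetical helper illustrating the expected shapes.

    anchors:   (N, D)
    positives: (N, D)
    negatives: (N, M, D), i.e. M negatives for each anchor
    """
    if anchors.dim() != 2 or positives.shape != anchors.shape:
        raise ValueError("anchors and positives must both be (N, D)")
    if (negatives.dim() != 3 or negatives.shape[0] != anchors.shape[0]
            or negatives.shape[2] != anchors.shape[1]):
        raise ValueError("negatives must be (N, M, D)")
    # Assumed cosine similarity: L2-normalize along the embedding dimension
    anchors = F.normalize(anchors, dim=-1)
    positives = F.normalize(positives, dim=-1)
    negatives = F.normalize(negatives, dim=-1)
    # (N, 1): each anchor against its own positive
    pos_logits = (anchors * positives).sum(dim=-1, keepdim=True)
    # (N, M): each anchor against its own M negatives
    neg_logits = torch.einsum("nd,nmd->nm", anchors, negatives)
    # (N, 1 + M): the positive sits in column 0 of every row
    logits = torch.cat([pos_logits, neg_logits], dim=1) / temperature
    labels = torch.zeros(anchors.shape[0], dtype=torch.long, device=anchors.device)
    return F.cross_entropy(logits, labels)
```

For example, `paired_info_nce(torch.randn(160, 128), torch.randn(160, 128), torch.randn(160, 100, 128))` returns a scalar, while passing the original 3D/4D tensors raises a `ValueError`.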
If you merge the 8 and 20 dimensions into a single dimension, it should work with `InfoNCE(negative_mode='paired')`, but I'm not sure that is what you need. Let me know if it works for you.
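```python
import torch
from info_nce import InfoNCE
# batch=8, pairs=20, embedding dim=128, 100 negatives per pair
anchors = torch.randn(8, 20, 128)
positives = torch.randn(8, 20, 128)
negatives = torch.randn(100, 8, 20, 128)
loss = InfoNCE(negative_mode='paired')
# Merge batch and pair dims: (8, 20, 128) -> (160, 128), i.e. (N, D)
anchors = anchors.reshape(-1, 128)
positives = positives.reshape(-1, 128)
# (100, 8, 20, 128) -> (100, 160, 128) -> (160, 100, 128), i.e. (N, M, D)
negatives = negatives.reshape(100, -1, 128).transpose(0, 1)
output = loss(anchors, positives, negatives)
```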
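The reshape simply treats each of the 8 * 20 = 160 (batch, pair) positions as an independent anchor. If you would rather keep the 3D layout, a sketch like the one below computes the same kind of loss without reshaping, by carrying the pair axis through the `einsum`. The helper `info_nce_3d` is hypothetical, and cosine similarity with `temperature=0.1` is assumed; it is not part of the `info_nce` package.

```python
import torch
import torch.nn.functional as F


def info_nce_3d(anchors, positives, negatives, temperature=0.1):
    """Hypothetical helper: paired InfoNCE on un-flattened tensors.

    anchors:   (B, P, D)    batch, pairs, embedding dim
    positives: (B, P, D)
    negatives: (K, B, P, D) K negatives for every (batch, pair) position
    """
    # Assumed cosine similarity: L2-normalize along the embedding dimension
    anchors = F.normalize(anchors, dim=-1)
    positives = F.normalize(positives, dim=-1)
    negatives = F.normalize(negatives, dim=-1)
    # (B, P, 1): each anchor against its own positive
    pos_logits = (anchors * positives).sum(dim=-1, keepdim=True)
    # (B, P, K): each anchor against each of its K negatives
    neg_logits = torch.einsum("bpd,kbpd->bpk", anchors, negatives)
    logits = torch.cat([pos_logits, neg_logits], dim=-1) / temperature
    # The positive is class 0 at every (batch, pair) position
    labels = torch.zeros(logits.shape[:-1], dtype=torch.long, device=logits.device)
    # cross_entropy wants the class dim second: (B, P, 1+K) -> (B, 1+K, P)
    return F.cross_entropy(logits.transpose(1, 2), labels)
```

Called on tensors of shape (8, 20, 128), (8, 20, 128) and (100, 8, 20, 128), it returns a scalar averaged over all 160 positions, which under these assumptions equals flattening to (160, 128) and (160, 100, 128) first and applying the same formula.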
My inputs have these shapes: anchors: torch.Size([8, 20, 128]), positives: torch.Size([8, 20, 128]), negatives: torch.Size([100, 8, 20, 128])
8 = batch size, 20 = number of pairs, 128 = embedding dimension, 100 = number of negative samples (shuffled from the positive samples)
So, can I calculate the similarity for 3D inputs with your code? Thanks.
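For the similarity part of the question, here is a minimal sketch, assuming cosine similarity is the measure wanted, that computes the similarities directly on the 3D tensors by broadcasting, with no reshaping. It uses only `torch.nn.functional.cosine_similarity`; no `InfoNCE` call is involved.

```python
import torch
import torch.nn.functional as F

anchors = torch.randn(8, 20, 128)
positives = torch.randn(8, 20, 128)
negatives = torch.randn(100, 8, 20, 128)

# (8, 20): each anchor against its own positive, reduced over the 128-dim axis
pos_sim = F.cosine_similarity(anchors, positives, dim=-1)
# (100, 8, 20): anchors (1, 8, 20, 128) broadcast against all 100 negatives
neg_sim = F.cosine_similarity(anchors.unsqueeze(0), negatives, dim=-1)
```

These are raw similarities in [-1, 1]; turning them into a contrastive loss still requires a temperature and a softmax over the positive and negatives.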