RElbers / info-nce-pytorch

PyTorch implementation of the InfoNCE loss for self-supervised learning.
MIT License

calculate similarity for 3-dim input #4

Closed. dangvansam closed this issue 2 years ago.

dangvansam commented 2 years ago

My inputs are: anchors: torch.Size([8, 20, 128]), positives: torch.Size([8, 20, 128]), negatives: torch.Size([100, 8, 20, 128])

Here 8 is the batch size, 20 is the number of pairs, 128 is the embedding dimension, and 100 is the number of negative samples (obtained by shuffling the positive samples).
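For reference, one way to build negatives of that shape "by shuffling the positives" is sketched below. This is an assumption about the setup, not code from the thread: the helper name `shuffled_negatives` is hypothetical, and each of the 100 negative sets is taken to be a random permutation over all 8 * 20 positive embeddings.

```python
import torch

# Hypothetical helper: build M negative sets by shuffling the positives.
# Assumption: each set is a random permutation over all batch * pairs rows;
# the issue does not say exactly how the shuffling is done.
def shuffled_negatives(positives: torch.Tensor, num_negatives: int) -> torch.Tensor:
    batch, pairs, dim = positives.shape
    flat = positives.reshape(-1, dim)                 # (batch * pairs, dim)
    negs = []
    for _ in range(num_negatives):
        perm = torch.randperm(flat.shape[0])          # random reordering of rows
        negs.append(flat[perm].reshape(batch, pairs, dim))
    return torch.stack(negs)                          # (num_negatives, batch, pairs, dim)

positives = torch.randn(8, 20, 128)
negatives = shuffled_negatives(positives, 100)
print(negatives.shape)  # torch.Size([100, 8, 20, 128])
```

Note that a random permutation can leave a row in place, so with this naive scheme a "negative" can occasionally be identical to its own positive.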

So, can I calculate the similarity for 3D inputs with your code? Thanks.

RElbers commented 2 years ago

Hi. This kind of input won't work, because the shapes don't match what the function expects. The shapes should be: anchors: (N, D), positives: (N, D), negatives: (N, M, D). If you merge the dimensions of size 8 and 20 into a single batch dimension (N = 8 * 20 = 160), it should work with InfoNCE(negative_mode='paired'), but I'm not sure that is what you need. Let me know if it works for you.

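    import torch
    from info_nce import InfoNCE

    anchors = torch.randn(8, 20, 128)          # (batch, pairs, D)
    positives = torch.randn(8, 20, 128)        # (batch, pairs, D)
    negatives = torch.randn(100, 8, 20, 128)   # (M, batch, pairs, D)

    loss = InfoNCE(negative_mode='paired')
    # Merge batch and pair dims: (8, 20, 128) -> (160, 128), i.e. (N, D)
    anchors = anchors.reshape(-1, 128)
    positives = positives.reshape(-1, 128)
    # (100, 8, 20, 128) -> (100, 160, 128) -> (160, 100, 128), i.e. (N, M, D)
    negatives = negatives.reshape(100, -1, 128).transpose(0, 1)
    output = loss(anchors, positives, negatives)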
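To see what the flattened shapes mean, here is a minimal plain-PyTorch sketch of InfoNCE with per-sample ("paired") negatives. It is an illustration, not the library's source: the function name `paired_info_nce` is hypothetical, and cosine similarity with a temperature of 0.1 is an assumption. Each of the 160 flattened anchors is scored against its own positive (column 0 of the logits) and its own 100 negatives, and cross-entropy pushes the positive to win. The final assert checks that the reshape-then-transpose keeps every negative attached to the right (batch, pair) slot.

```python
import torch
import torch.nn.functional as F

def paired_info_nce(anchors, positives, negatives, temperature=0.1):
    """Sketch of InfoNCE where sample i has its own M negatives.

    anchors:   (N, D)
    positives: (N, D)
    negatives: (N, M, D)
    Assumption: cosine similarity and temperature 0.1.
    """
    anchors = F.normalize(anchors, dim=-1)
    positives = F.normalize(positives, dim=-1)
    negatives = F.normalize(negatives, dim=-1)
    # Similarity of each anchor with its own positive: (N, 1)
    pos_logits = (anchors * positives).sum(dim=-1, keepdim=True)
    # Similarity of anchor i with each of its own M negatives: (N, M)
    neg_logits = torch.einsum('nd,nmd->nm', anchors, negatives)
    logits = torch.cat([pos_logits, neg_logits], dim=1) / temperature
    # The positive sits in column 0 for every row.
    labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
    return F.cross_entropy(logits, labels)

# Same flattening as in the reply: 8 * 20 = 160 samples, 100 negatives each.
anchors = torch.randn(8, 20, 128).reshape(-1, 128)
positives = torch.randn(8, 20, 128).reshape(-1, 128)
neg4d = torch.randn(100, 8, 20, 128)                     # (M, batch, pairs, D)
negatives = neg4d.reshape(100, -1, 128).transpose(0, 1)  # (N, M, D)
print(negatives.shape)  # torch.Size([160, 100, 128])

# Pair (b=3, p=7) lands in row 3 * 20 + 7; its 42nd negative is unchanged.
assert torch.equal(negatives[3 * 20 + 7, 42], neg4d[42, 3, 7])

loss = paired_info_nce(anchors, positives, negatives)
```

Because the 160 flattened rows are treated independently, each anchor only sees its own 100 negatives; the positives of other pairs are not used as extra negatives in this mode.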