RElbers / info-nce-pytorch

PyTorch implementation of the InfoNCE loss for self-supervised learning.
MIT License

How to use InfoNCE for prototypes with shape [b, num_cls, num_tokens, dim]? #19

Open zzzyzh opened 1 month ago

zzzyzh commented 1 month ago

Hi, I’m working with a set of prototypes that have the shape [b, num_cls, num_tokens, dim]. My goal is to use InfoNCE loss to maximize inter-class differences.

I have the following questions:

1. How should I apply InfoNCE to my prototypes in order to increase the distance between different classes?
2. Should I treat each of the num_tokens tokens as a separate sample for contrastive learning, or is there a better way to structure the loss computation?

Any guidance on how to set up the loss function properly for this scenario would be greatly appreciated!
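One possible setup, offered as a minimal sketch and not as this repository's API: flatten the prototypes to rows of shape [b * num_cls * num_tokens, dim], label each row with its class index, and use a multi-positive (supervised-contrastive style) InfoNCE in plain PyTorch. Rows of the same class act as positives and rows of other classes as negatives. The function name `prototype_info_nce` and the default `temperature` are hypothetical choices made for illustration, and the sketch assumes class index c means the same class in every batch element.

```python
import torch
import torch.nn.functional as F


def prototype_info_nce(prototypes: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    """Multi-positive InfoNCE over prototypes of shape [b, num_cls, num_tokens, dim].

    Hypothetical helper: same-class rows are positives, other-class rows are negatives.
    """
    b, num_cls, num_tokens, dim = prototypes.shape

    # One row per token: [N, dim] with N = b * num_cls * num_tokens, L2-normalized.
    feats = F.normalize(prototypes.reshape(-1, dim), dim=-1)

    # Class id of every row, laid out in the same (b, num_cls, num_tokens) order.
    labels = (
        torch.arange(num_cls, device=prototypes.device)
        .view(1, num_cls, 1)
        .expand(b, num_cls, num_tokens)
        .reshape(-1)
    )

    # Pairwise cosine similarities scaled by the temperature: [N, N].
    logits = feats @ feats.t() / temperature

    # A row is never its own positive or negative.
    self_mask = torch.eye(labels.numel(), dtype=torch.bool, device=logits.device)
    pos_mask = (labels[:, None] == labels[None, :]) & ~self_mask

    # Log-softmax over all other rows (self excluded from the denominator).
    logits = logits.masked_fill(self_mask, float("-inf"))
    log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)

    # Average the log-probability over each anchor's positives.
    pos_count = pos_mask.sum(dim=1)
    valid = pos_count > 0  # anchors that have at least one positive
    pos_log_prob = log_prob.masked_fill(~pos_mask, 0.0).sum(dim=1)
    return -(pos_log_prob[valid] / pos_count[valid]).mean()


if __name__ == "__main__":
    torch.manual_seed(0)
    protos = torch.randn(2, 3, 4, 8, requires_grad=True)  # [b, num_cls, num_tokens, dim]
    loss = prototype_info_nce(protos)
    loss.backward()
    print(loss.item())
```

Treating every token as its own sample keeps token-level detail but makes the similarity matrix grow quadratically in b * num_cls * num_tokens. Mean-pooling over the num_tokens axis first (for example `prototypes.mean(dim=2, keepdim=True)`) is a cheaper alternative; it needs b > 1 so that each class still has a positive.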

zzzyzh commented 1 month ago

Why is the required shape[0] equal to bs (batch size)?
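A likely reason, sketched here in plain PyTorch and not taken from the repository's code: in the standard in-batch formulation of InfoNCE, row i of the query is paired with row i of the positive key, and every other row serves as a negative. The first dimension is therefore the number of samples, and both inputs must be 2-D with matching shapes. The function name `in_batch_info_nce` and the `temperature` default are assumptions for illustration.

```python
import torch
import torch.nn.functional as F


def in_batch_info_nce(query: torch.Tensor, positive_key: torch.Tensor,
                      temperature: float = 0.1) -> torch.Tensor:
    """In-batch InfoNCE: query[i] matches positive_key[i]; other rows are negatives."""
    if query.dim() != 2 or query.shape != positive_key.shape:
        raise ValueError("expected two tensors of identical shape [num_samples, dim]")

    query = F.normalize(query, dim=-1)
    positive_key = F.normalize(positive_key, dim=-1)

    # logits[i, j] = similarity between query i and key j: [num_samples, num_samples].
    logits = query @ positive_key.t() / temperature

    # The correct "class" for row i is column i, hence one target per sample.
    targets = torch.arange(query.shape[0], device=query.device)
    return F.cross_entropy(logits, targets)


if __name__ == "__main__":
    q = torch.randn(8, 16)
    k = q + 0.05 * torch.randn(8, 16)  # slightly perturbed positives
    print(in_batch_info_nce(q, k).item())
```

Under this formulation a 4-D tensor such as [b, num_cls, num_tokens, dim] has to be reshaped to [num_samples, dim] before it is passed in, for example with `x.reshape(-1, x.shape[-1])`.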