KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/
MIT License

NTXentLoss with sequences #235

Closed: asafbenj closed this issue 3 years ago

asafbenj commented 3 years ago

Hi! First, thanks a lot for this awesome contribution! I was wondering whether and how one could use NTXentLoss for sequential data tasks, such as ASR or NLP. Say I'm using a Transformer and my data is a 3D tensor of shape (n_tokens, batch_size, model_dim). Is it possible to use NTXentLoss in this case? I guess one straightforward way would be to call NTXentLoss for each token separately and then sum up these losses, but I'm not sure that would be the most efficient or accurate way (I'm pretty new to most of this stuff). Anyway, any advice would be highly appreciated. Thanks again!
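The per-token idea in the question, and the alternative of folding the token dimension into the batch dimension, can be sketched in plain PyTorch. This is a minimal sketch under stated assumptions, not the library's code: positives are assumed to come from two hypothetical "views" of each sequence (for example two augmentations), with the same token position in both views forming a positive pair. `nt_xent`, `per_token_loss` and `flattened_loss` are hypothetical helpers, and `nt_xent` re-implements NT-Xent by hand so the shapes are visible. With pytorch-metric-learning itself, `losses.NTXentLoss(temperature=...)` is called as `loss_func(embeddings, labels)` on the same kind of flattened (N, D) embeddings and (N,) labels.

```python
import torch
import torch.nn.functional as F


def nt_xent(embeddings: torch.Tensor, labels: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    """Hand-written NT-Xent sketch (hypothetical helper, not the library's code).

    embeddings: (N, D). labels: (N,). Rows with equal labels are positives,
    rows with different labels are negatives.
    """
    z = F.normalize(embeddings, dim=1)
    sim = z @ z.T / temperature  # (N, N) temperature-scaled cosine similarities
    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    eye = torch.eye(len(labels), dtype=torch.bool, device=labels.device)
    pos_mask = same & ~eye  # positive pairs, excluding each row paired with itself
    # log(sum over negatives of exp(sim)) for each anchor row, shape (N, 1)
    neg_lse = torch.logsumexp(sim.masked_fill(same, float("-inf")), dim=1, keepdim=True)
    # -log(exp(s_ap) / (exp(s_ap) + sum_n exp(s_an))) for every (anchor, positive) pair
    per_pair = torch.logaddexp(sim, neg_lse) - sim
    return per_pair[pos_mask].mean()


def per_token_loss(view_a: torch.Tensor, view_b: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    """The approach from the question: one loss per token position, summed."""
    n_tokens, batch_size, _ = view_a.shape
    # Sample b in view_a and sample b in view_b share label b.
    labels = torch.arange(batch_size, device=view_a.device).repeat(2)
    losses = [
        nt_xent(torch.cat([view_a[t], view_b[t]]), labels, temperature)
        for t in range(n_tokens)
    ]
    return torch.stack(losses).sum()


def flattened_loss(view_a: torch.Tensor, view_b: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    """Alternative: fold tokens into the batch and compute the loss once."""
    n_tokens, batch_size, dim = view_a.shape
    emb = torch.cat([view_a.reshape(-1, dim), view_b.reshape(-1, dim)])
    # Row t * batch_size + b of each view gets label t * batch_size + b,
    # so the same (token, sample) position in the two views is a positive pair.
    labels = torch.arange(n_tokens * batch_size, device=view_a.device).repeat(2)
    return nt_xent(emb, labels, temperature)


if __name__ == "__main__":
    torch.manual_seed(0)
    a = torch.randn(5, 8, 16)  # (n_tokens, batch_size, model_dim)
    b = a + 0.1 * torch.randn_like(a)  # hypothetical second "view" of the batch
    print(per_token_loss(a, b).item(), flattened_loss(a, b).item())
```

The two versions are not equivalent. In the per-token loop each anchor only sees the same token position of the other samples as negatives, while the flattened call also treats every other token, including other positions of the same sequence, as a negative. That may or may not be what a given task wants. The flattened call replaces a Python loop with one matrix product, but its similarity matrix has (2 * n_tokens * batch_size)^2 entries, so memory grows quickly with sequence length.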

KevinMusgrave commented 3 years ago

Thanks for your interest. Here are some related issues that might help:

How to use NTXentLoss as in Contrastive Predictive Coding: https://github.com/KevinMusgrave/pytorch-metric-learning/issues/179
Using BERT as trunk: https://github.com/KevinMusgrave/pytorch-metric-learning/issues/29

Also, you could check out the DeCLUTR repo, which uses NTXentLoss for NLP. There, the loss is called inside the model's forward function, where it is stored as self._loss.
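DeCLUTR's actual code is not reproduced here. As a hypothetical sketch of the general pattern the reply describes (a model that holds its loss as `self._loss` and calls it in `forward`), the module below accepts any callable with the `(embeddings, labels)` signature, such as pytorch-metric-learning's `losses.NTXentLoss`. `SequenceContrastiveModel`, the `nn.Linear` projection standing in for a Transformer, and the two-view input are all assumptions.

```python
from typing import Callable

import torch
from torch import nn

# Any callable mapping (embeddings, labels) to a scalar loss, e.g. an instance of
# pytorch-metric-learning's losses.NTXentLoss (assumed, not required to run this file).
LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


class SequenceContrastiveModel(nn.Module):
    """Hypothetical model that keeps its contrastive loss in self._loss."""

    def __init__(self, model_dim: int, proj_dim: int, loss: LossFn):
        super().__init__()
        self.proj = nn.Linear(model_dim, proj_dim)  # stand-in for a Transformer + head
        self._loss = loss

    def forward(self, view_a: torch.Tensor, view_b: torch.Tensor) -> torch.Tensor:
        # view_a, view_b: (n_tokens, batch_size, model_dim), two views of one batch.
        n_tokens, batch_size, _ = view_a.shape
        emb = self.proj(torch.cat([view_a, view_b], dim=0))  # (2 * n_tokens, batch_size, proj_dim)
        emb = emb.reshape(-1, emb.shape[-1])  # (2 * n_tokens * batch_size, proj_dim)
        # The same (token, sample) position in both views shares a label.
        labels = torch.arange(n_tokens * batch_size, device=emb.device).repeat(2)
        return self._loss(emb, labels)
```

For example, after `from pytorch_metric_learning import losses`, one could build `SequenceContrastiveModel(512, 128, losses.NTXentLoss(temperature=0.1))`. Passing the loss in from outside keeps the model independent of which contrastive loss is plugged in.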

asafbenj commented 3 years ago

Wow, that was super fast and helpful. Thanks a lot!