KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/
MIT License

[Question]: CrossBatchMemory pos/neg pair creation #599

Closed: DanielRoeder1 closed this issue 1 year ago

DanielRoeder1 commented 1 year ago

I am trying to train a bi-encoder network (non-siamese) to encode chat queries and wiki documents.

Intuitively, I would use NTXentLoss and pass the outputs of the two encoders as `embeddings` and `ref_emb`. Following the guidance in https://github.com/KevinMusgrave/pytorch-metric-learning/issues/549#issuecomment-1312376923, this would guarantee that during the loss calculation, queries are compared only to documents.
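To make the idea of a separate reference set concrete, here is a small pure-Python sketch of label-based pair mining. `all_pairs_indices` is a hypothetical helper that only mimics the behaviour discussed in this thread; it is not the library's implementation. It assumes queries and documents share integer labels that link each query to its relevant document.

```python
from typing import List, Optional, Tuple

Pair = Tuple[int, int]  # (anchor index, reference index)


def all_pairs_indices(
    labels: List[int], ref_labels: Optional[List[int]] = None
) -> Tuple[List[Pair], List[Pair]]:
    """Hypothetical sketch: return (positive_pairs, negative_pairs).

    With ref_labels=None the batch is compared with itself and the diagonal
    (an item paired with itself) is skipped. With a separate reference set,
    every pair is anchor -> reference, so anchors never meet each other.
    """
    compare_with_self = ref_labels is None
    refs = labels if compare_with_self else ref_labels
    positives: List[Pair] = []
    negatives: List[Pair] = []
    for i, a in enumerate(labels):
        for j, r in enumerate(refs):
            if compare_with_self and i == j:
                continue  # never pair an item with itself
            (positives if a == r else negatives).append((i, j))
    return positives, negatives


# Assumption for this sketch: query i and document i share label i.
query_labels = [0, 1, 2]
doc_labels = [0, 1, 2]
pos, neg = all_pairs_indices(query_labels, doc_labels)
print(pos)  # [(0, 0), (1, 1), (2, 2)]: each is (query index, document index)
```

Because the anchors come only from the query side and the references only from the document side, no query-to-query pair can appear.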

Since CrossBatchMemory takes the embeddings as one concatenated input, if I use `enqueue_mask` to write only the queries to the buffer, the query embeddings will be compared with themselves as well as with the documents. As I understand the `get_all_pairs_indices` function, this happens because the query embeddings are part of both the buffer labels and the input labels.
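The concern can be modeled in plain Python for the case where the queries are written to the buffer but also stay in the batch that is compared against it. This is a hypothetical model of the scenario, not CrossBatchMemory's code: items are `(id, label)` tuples, the ids are invented only for bookkeeping, and a pair counts as a self-comparison when both sides carry the same id.

```python
from typing import List, Tuple

# (hypothetical item id, label); ids let us detect an item meeting its own copy
Item = Tuple[str, int]


def positive_pairs(batch: List[Item], memory: List[Item]) -> List[Tuple[str, str]]:
    """Return (batch_id, memory_id) for every batch/memory pair sharing a label."""
    return [
        (b_id, m_id)
        for b_id, b_label in batch
        for m_id, m_label in memory
        if b_label == m_label
    ]


queries = [("q0", 0), ("q1", 1)]
docs = [("d0", 0), ("d1", 1)]

# One concatenated input; only the queries go into the buffer,
# but in this model they are NOT removed from the compared batch.
batch = queries + docs
memory = list(queries)

pairs = positive_pairs(batch, memory)
self_pairs = [(b, m) for b, m in pairs if b == m]
print(self_pairs)  # [('q0', 'q0'), ('q1', 'q1')]
```

Each buffered query matches its own label, so under this model it would be paired with itself.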

My question is: would this affect training? Intuitively, I would compare queries only to documents. I can see why multiple queries that are similar to the same document should also be similar to each other, but I am not sure whether this has any effect on the two-encoder architecture I am using.

Curious and thankful for any suggestions.

KevinMusgrave commented 1 year ago

If you're using `enqueue_mask`, the embeddings added to the buffer are removed from the `embeddings` variable, so there is no self-comparison. You can see this in the code, where `embeddings = embeddings[~enqueue_mask]`:

https://github.com/KevinMusgrave/pytorch-metric-learning/blob/691a6354be130547a8e26e170b86cf65c36cd791/src/pytorch_metric_learning/losses/cross_batch_memory.py#L62-L67
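The effect of that masking line can be modeled on plain lists. This is a simplified, hypothetical sketch: `split_by_enqueue_mask` and `positive_pairs` are illustrative helpers, not library functions, and the real loss also mines negatives and uses the rest of the buffer. It only shows that once the enqueued items are dropped from the compared batch, no item can be paired with its own buffered copy.

```python
from typing import List, Tuple

Item = Tuple[str, int]  # (hypothetical item id, label)


def split_by_enqueue_mask(
    items: List[Item], enqueue_mask: List[bool]
) -> Tuple[List[Item], List[Item]]:
    """Mimic embeddings[enqueue_mask] and embeddings[~enqueue_mask] on a list."""
    to_queue = [item for item, m in zip(items, enqueue_mask) if m]
    remaining = [item for item, m in zip(items, enqueue_mask) if not m]
    return to_queue, remaining


def positive_pairs(batch: List[Item], memory: List[Item]) -> List[Tuple[str, str]]:
    """Return (batch_id, memory_id) for every batch/memory pair sharing a label."""
    return [
        (b_id, m_id)
        for b_id, b_label in batch
        for m_id, m_label in memory
        if b_label == m_label
    ]


queries = [("q0", 0), ("q1", 1)]
docs = [("d0", 0), ("d1", 1)]
batch = queries + docs
enqueue_mask = [True, True, False, False]  # write only the queries to the buffer

to_queue, remaining = split_by_enqueue_mask(batch, enqueue_mask)
memory: List[Item] = []
memory.extend(to_queue)  # stand-in for adding the enqueued items to the buffer

pairs = positive_pairs(remaining, memory)
print(pairs)  # [('d0', 'q0'), ('d1', 'q1')]
```

In this setup the documents are the only items left in the compared batch, so every pair is document-to-query.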

DanielRoeder1 commented 1 year ago

Ah, perfect. I missed this line when going through the code.

Thanks for the quick response!