UKPLab / sentence-transformers

State-of-the-Art Text Embeddings
https://www.sbert.net
Apache License 2.0

Adding graph embeddings as another feature to fine-tune SBERT model #1223

Open Nicolabo opened 2 years ago

Nicolabo commented 2 years ago

I am working through the example of using SBERT for Information Retrieval on the Quora Duplicate Questions dataset. I am wondering how to incorporate additional embedding-like features, such as graph embeddings (represented as a single fixed-length embedding vector per example).

My idea is to modify the MultipleNegativesRankingLoss.forward function so that it accepts not only sentence_features as input but also the extra feature embeddings, which could then be concatenated with the reps variable. Something like this:

from typing import Dict, Iterable

import torch
from torch import Tensor


def _concatenate_features(self, sentence_feature, other_feature):
    # Encode the sentence and append the precomputed graph embedding
    sentence_embedding = self.model(sentence_feature)['sentence_embedding']
    return torch.cat([sentence_embedding, other_feature], dim=1)

def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor, other_features: Iterable[Tensor]):
    reps = [self._concatenate_features(sf, of) for sf, of in zip(sentence_features, other_features)]
    embeddings_a = reps[0]              # anchor embeddings
    embeddings_b = torch.cat(reps[1:])  # positives (and optional hard negatives)

    scores = self.similarity_fct(embeddings_a, embeddings_b) * self.scale
    labels = torch.tensor(range(len(scores)), dtype=torch.long,
                          device=scores.device)  # Example a[i] should match with b[i]
    return self.cross_entropy_loss(scores, labels)

As far as I can tell, this does not violate the assumptions of the loss function. Do you think such a modification makes sense? Obviously it might also require some changes on the model side.
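
For context, here is a minimal inference-side sketch of the same idea, i.e. concatenating a fixed-length graph embedding to each sentence embedding before computing similarities. The model name and the graph_embeddings lookup are illustrative placeholders, not part of the library:

import torch
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer('all-MiniLM-L6-v2')  # any SBERT model; placeholder choice

queries = ["How can I learn Python quickly?"]
corpus = ["What is the best way to learn Python?", "How do I bake sourdough bread?"]

# Hypothetical precomputed graph embeddings: one fixed-length vector per text
graph_embeddings = {text: torch.randn(32) for text in queries + corpus}

def embed(texts):
    sent_emb = model.encode(texts, convert_to_tensor=True)        # (n, sentence_dim)
    graph_emb = torch.stack([graph_embeddings[t] for t in texts]).to(sent_emb.device)  # (n, 32)
    return torch.cat([sent_emb, graph_emb], dim=1)                # (n, sentence_dim + 32)

# Similarities are computed on the concatenated vectors for both queries and corpus
scores = util.cos_sim(embed(queries), embed(corpus))
print(scores)

During training, the same concatenation would have to be applied consistently inside the modified loss above, so that the similarity scores are computed over sentence-plus-graph vectors on both sides.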

nreimers commented 2 years ago

Yes, looks good