UKPLab / sentence-transformers

State-of-the-Art Text Embeddings
https://www.sbert.net
Apache License 2.0

Can a custom loss function be used with model.fit? #2562

Open · shangh1 opened this issue 7 months ago

shangh1 commented 7 months ago

Hi, I'm using this repository to train a cross-encoder model and would like to use a custom loss function. The code is below. The problem is that model.fit doesn't document which model output is passed to the loss function. Later, I'd also like to add another argument to this custom loss so that it can compute a LambdaRank loss. Thanks!

def my_loss(output, label):
    loss = ...  # custom loss computation goes here
    return loss

model.fit(train_dataloader=train_dataloader,
          loss_fct=my_loss)

tomaarsen commented 7 months ago

Hello!

Yes, this is possible. If you pass your function as loss_fct, it will be called right here: https://github.com/UKPLab/sentence-transformers/blob/737353354fbdf1a419eee864f998ffe9fdf3b682/sentence_transformers/cross_encoder/CrossEncoder.py#L240 i.e. as loss_fct(logits, labels), with the logits first and the labels second.
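As a rough illustration of that call convention, here is a minimal custom loss. It assumes a single-output cross-encoder (num_labels=1), where, as far as I can tell from the linked version of CrossEncoder.fit, the logits are flattened to shape (batch_size,) and the labels are float targets. The loss simply re-implements binary cross-entropy on raw logits, so it can be checked against torch's built-in; the name my_loss and the toy tensors are illustrative only.

```python
import torch
import torch.nn.functional as F


def my_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Custom loss, called by fit as loss_fct(logits, labels).

    Assumption: num_labels == 1, so logits has shape (batch_size,) and
    labels holds float relevance targets in [0, 1].
    """
    # Numerically stable binary cross-entropy on raw logits:
    # -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))],
    # using log(1 - sigmoid(x)) == log(sigmoid(-x)).
    per_example = -(labels * F.logsigmoid(logits) + (1 - labels) * F.logsigmoid(-logits))
    return per_example.mean()  # fit expects a scalar loss


# Simulate the call fit makes for one batch of 4 (fake logits and labels).
logits = torch.tensor([2.0, -1.0, 0.5, -3.0], requires_grad=True)
labels = torch.tensor([1.0, 0.0, 1.0, 0.0])

loss = my_loss(logits, labels)
loss.backward()  # gradients flow to the logits (and, inside fit, to the model)
print(loss.item())
```

With the real model you would pass it as model.fit(train_dataloader=train_dataloader, loss_fct=my_loss); the main requirement is that it returns a scalar tensor that backpropagates to the logits.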

shangh1 commented 7 months ago

Thanks for the quick reply! Got it: the loss function is called with two arguments, the logits and the labels.
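Since fit only ever passes (logits, labels), an extra argument like the one planned for a LambdaRank loss has to be bound beforehand, for example with functools.partial or a closure. The sketch below does this for a RankNet-style pairwise logistic loss with a hypothetical sigma parameter. It is not full LambdaRank (which additionally weights each pair by the |ΔNDCG| obtained by swapping the two items), and it assumes every example in a batch belongs to the same query, which a plain shuffled DataLoader does not guarantee.

```python
import functools

import torch
import torch.nn.functional as F


def pairwise_ranknet_loss(logits: torch.Tensor, labels: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """RankNet-style pairwise logistic loss over all pairs in the batch.

    Assumptions: logits has shape (batch_size,), all examples in the batch
    belong to the same query, and labels are graded relevance scores
    (higher = more relevant).
    """
    # score_diff[i, j] = s_i - s_j and label_diff[i, j] = y_i - y_j
    score_diff = logits.unsqueeze(1) - logits.unsqueeze(0)
    label_diff = labels.unsqueeze(1) - labels.unsqueeze(0)

    # Only pairs where item i is strictly more relevant than item j contribute.
    pair_mask = label_diff > 0
    if not pair_mask.any():
        # No informative pairs (e.g. all labels tied): zero loss that still
        # depends on the logits, so backward() does not fail.
        return logits.sum() * 0.0

    # log(1 + exp(-sigma * (s_i - s_j))) == softplus(-sigma * (s_i - s_j))
    pair_losses = F.softplus(-sigma * score_diff[pair_mask])
    return pair_losses.mean()


# fit only calls loss_fct(logits, labels), so bind the extra argument up front.
# sigma=2.0 is an arbitrary illustrative value.
ranking_loss = functools.partial(pairwise_ranknet_loss, sigma=2.0)

# Toy batch: three candidates for one query with graded relevance 2 > 1 > 0.
logits = torch.tensor([0.2, 1.5, -0.7], requires_grad=True)
labels = torch.tensor([2.0, 1.0, 0.0])

loss = ranking_loss(logits, labels)
loss.backward()
print(loss.item())
```

With a cross-encoder you would then call model.fit(train_dataloader=train_dataloader, loss_fct=ranking_loss). A closure or a small nn.Module holding sigma as an attribute would work just as well; partial keeps the loss itself a plain function that is easy to test in isolation.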