UKPLab / sentence-transformers

State-of-the-Art Text Embeddings
https://www.sbert.net
Apache License 2.0

`KeyError: 'eval_loss'` when `load_best_model_at_end=True` with multiple eval datasets #2749

Open LakeYin opened 5 months ago

LakeYin commented 5 months ago

System Information

Linux x86-64 Python 3.10.5 sentence_transformers 3.0.1 transformers 4.41.2 datasets 2.19.2

Reproduction

Running on GPU:

from datasets import load_dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainingArguments, SentenceTransformerTrainer 
from sentence_transformers.losses import CoSENTLoss, MultipleNegativesRankingLoss

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

stsb_pair_score_train = load_dataset("sentence-transformers/stsb", split="train[:10000]")
quora_pair_train = load_dataset("sentence-transformers/quora-duplicates", "pair", split="train[:10000]")

train_dataset = {
    "stsb": stsb_pair_score_train,
    "quora": quora_pair_train
}

stsb_pair_score_dev = load_dataset("sentence-transformers/stsb", split="validation")
quora_pair_dev = load_dataset("sentence-transformers/quora-duplicates", "pair", split="train[10000:11000]")

eval_dataset = {
    "stsb": stsb_pair_score_dev,
    "quora": quora_pair_dev
}

mnrl_loss = MultipleNegativesRankingLoss(model)
cosent_loss = CoSENTLoss(model)

losses = {
    "stsb": cosent_loss,
    "quora": mnrl_loss
}

args = SentenceTransformerTrainingArguments(
    output_dir="test_model",
    evaluation_strategy="steps",
    load_best_model_at_end=True
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=losses
)
trainer.train()

Error

  File ".../test_sbert.py", line 44, in <module>
    trainer.train()
  File ".../python3.10/site-packages/transformers/trainer.py", line 1885, in train
    return inner_training_loop(
  File ".../python3.10/site-packages/transformers/trainer.py", line 2291, in _inner_training_loop
    self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
  File ".../python3.10/site-packages/transformers/trainer.py", line 2732, in _maybe_log_save_evaluate
    self._save_checkpoint(model, trial, metrics=metrics)
  File ".../python3.10/site-packages/transformers/trainer.py", line 2824, in _save_checkpoint
    metric_value = metrics[metric_to_check]
KeyError: 'eval_loss'

Notably, this error does not happen when load_best_model_at_end=True is removed from args.

ganeshkrishnan1 commented 5 months ago

This error is thrown even with a single dataset.

bely66 commented 5 months ago

Same behavior for me.

tomaarsen commented 5 months ago

Hello!

Apologies for the delay, I was on vacation last week. When you use load_best_model_at_end=True, the trainer uses metric_for_best_model and greater_is_better to decide whether one evaluation score is better or worse than another. If metric_for_best_model is not set, it falls back to the default eval_loss. With multiple evaluation datasets (e.g. a DatasetDict), the evaluation losses are instead named per dataset, e.g. eval_stsb_loss or eval_quora_loss, so the default eval_loss doesn't exist and you get this error. I've just created a pull request on transformers to give a more useful error here, indicating that you can set metric_for_best_model in the SentenceTransformerTrainingArguments to specify which value you'd like to use.
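As a minimal sketch of that suggestion, assuming the dataset names from the reproduction above: the arguments below pick eval_stsb_loss as the metric to compare checkpoints on (eval_quora_loss would work the same way). The training_kwargs dict and the build_args helper are illustrative names, not part of either library. greater_is_better=False is set explicitly because a lower loss is better, and the default may not infer that for a custom metric name.

```python
# Keyword arguments for SentenceTransformerTrainingArguments.
# "eval_stsb_loss" follows the eval_<dataset name>_loss pattern described above;
# the dataset names ("stsb", "quora") come from the reproduction script.
training_kwargs = {
    "output_dir": "test_model",
    "evaluation_strategy": "steps",
    "load_best_model_at_end": True,
    # Tell the trainer which metric decides the "best" checkpoint.
    "metric_for_best_model": "eval_stsb_loss",
    # Lower loss is better, so state that explicitly.
    "greater_is_better": False,
}


def build_args(kwargs):
    """Illustrative helper: build the training arguments from the dict above."""
    # Imported lazily so the dict can be inspected without the library installed.
    from sentence_transformers import SentenceTransformerTrainingArguments

    return SentenceTransformerTrainingArguments(**kwargs)
```

Calling build_args(training_kwargs) and passing the result as args to SentenceTransformerTrainer in place of the original args should avoid the missing eval_loss lookup.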

If you're not using load_best_model_at_end=True, the trainer doesn't need to check which checkpoint is better than another, so it never looks up the metric and doesn't crash. That's why you only get the error with load_best_model_at_end=True.
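The lookup described here can be sketched in plain Python. This is a simplified illustration and not the actual transformers code: resolve_best_metric is a hypothetical helper, and the metric values are made up.

```python
def resolve_best_metric(metrics, load_best_model_at_end, metric_for_best_model=None):
    """Hypothetical helper mimicking how the metric for checkpoint comparison is found."""
    if not load_best_model_at_end:
        # No checkpoint comparison is needed, so the metric is never looked up.
        return None
    # Without an explicit choice, the evaluation loss is used.
    name = metric_for_best_model or "loss"
    if not name.startswith("eval_"):
        name = f"eval_{name}"
    # Raises KeyError if the metric is missing, as in the traceback above.
    return metrics[name]


# Made-up values; with several eval datasets the keys carry the dataset name.
metrics = {"eval_stsb_loss": 5.1, "eval_quora_loss": 0.3}

print(resolve_best_metric(metrics, load_best_model_at_end=False))
print(resolve_best_metric(metrics, True, metric_for_best_model="eval_stsb_loss"))
try:
    resolve_best_metric(metrics, True)
except KeyError as err:
    print("KeyError:", err)
```

Only the last call fails, because it is the only one that looks for the eval_loss key.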

I'm not sure why this error occurs with single datasets; I haven't been able to reproduce that yet.

imrankh46 commented 5 months ago

@LakeYin @tomaarsen is the loss in this image the same as the one you are using? Also, how can I use a custom loss in Sentence Transformers?

Screenshot_20240624-170023

tomaarsen commented 5 months ago

I think you hit send a bit too quickly, before the image could be added to the comment correctly @imrankh46.

imrankh46 commented 5 months ago

> I think you hit send a bit too quickly, before the image could be added to the comment correctly @imrankh46.

I just edited the comment, so you should be able to see the image now. I also tried to chat with you directly, but wasn't able to.

I implemented the loss used by the Alibaba team in the gte Qwen2 instruct model, but I'm not sure how to use or add a custom loss in Sentence Transformers.

Next, can we load a PEFT/LoRA model using Sentence Transformers?

imrankh46 commented 5 months ago

@tomaarsen Screenshot_20240624-174157

tomaarsen commented 5 months ago

Answered your first question in #2774

> Next, can we load a PEFT/LoRA model using Sentence Transformers?

Yes, this is possible, but only with a bit of a hacky workaround (for now). Learn more about it here: https://github.com/UKPLab/sentence-transformers/issues/2748#issuecomment-2173422897