SeanLee97 / AnglE

Train and Infer Powerful Sentence Embeddings with AnglE | 🔥 SOTA on STS and MTEB Leaderboard
https://arxiv.org/abs/2309.12871
MIT License

model.fit() Validation Loss not showing #98

Closed: NSC508 closed this issue 1 month ago

NSC508 commented 2 months ago

Hello Author, thank you so much for publishing this library; it's incredibly useful. I ran into an issue while fine-tuning the model for my downstream task and am wondering whether I am using the function incorrectly. I am following the documentation for custom training, and training runs normally, but I would like to see the validation loss alongside the training loss so that I can spot potential overfitting. However, fit() only seems to log the training loss. Is there a way to log the validation loss during training as well?
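The reason to log both losses is that overfitting typically shows up as validation loss stalling or rising while training loss keeps falling. Here is a minimal sketch of that check. It assumes per-epoch loss values have already been collected into two lists. The helper overfit_epoch and its patience parameter are hypothetical and not part of angle_emb.

```python
def overfit_epoch(train_losses, valid_losses, patience=2):
    """Return the epoch with the best validation loss once validation loss has
    failed to improve for `patience` consecutive epochs in which training loss
    still decreased; return None if that never happens.
    (Hypothetical helper, not part of angle_emb.)"""
    best_valid = float("inf")
    best_epoch = None
    stale = 0  # consecutive epochs: valid loss not improving while train loss drops
    for epoch, (train, valid) in enumerate(zip(train_losses, valid_losses)):
        if valid < best_valid:
            # New best validation loss: remember it and reset the counter.
            best_valid, best_epoch, stale = valid, epoch, 0
        elif epoch > 0 and train < train_losses[epoch - 1]:
            # Training still improves but validation does not: overfitting signal.
            stale += 1
            if stale >= patience:
                return best_epoch
        else:
            # Neither loss improved; not the overfitting pattern, so reset.
            stale = 0
    return None
```

For example, with training losses [1.0, 0.8, 0.6, 0.5, 0.4] and validation losses [1.1, 0.9, 0.95, 1.0, 1.1], the function returns 1: epoch 1 is the last point before the two losses diverge.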

SeanLee97 commented 2 months ago

Hi @NSC508,

Currently, angle_emb only outputs Spearman's correlation when a validation set is provided to the fit() function. We will consider adding evaluation loss output in the next version.
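For reference, Spearman's correlation measures how well the ranking of predicted similarity scores agrees with the ranking of the gold labels. Unlike a loss, higher is better. Below is a minimal pure-Python sketch of the metric: it assigns average ranks to ties, then takes the Pearson correlation of those ranks. It assumes the predicted scores are something like cosine similarities of each text1/text2 pair. It is an illustration, not angle_emb's own evaluation code.

```python
import math


def average_ranks(values):
    """Rank values 1..n in ascending order; tied values share the mean of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of values tied with values[order[i]].
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1  # mean of the 1-based positions i+1 .. j+1
        for k in range(i, j + 1):
            ranks[order[k]] = avg
        i = j + 1
    return ranks


def spearman(pred, gold):
    """Spearman's rho: Pearson correlation between the average ranks of pred and gold."""
    rp, rg = average_ranks(pred), average_ranks(gold)
    n = len(rp)
    mp, mg = sum(rp) / n, sum(rg) / n
    cov = sum((a - mp) * (b - mg) for a, b in zip(rp, rg))
    sp = math.sqrt(sum((a - mp) ** 2 for a in rp))
    sg = math.sqrt(sum((b - mg) ** 2 for b in rg))
    if sp == 0 or sg == 0:
        # A constant input has no ranking, so the correlation is undefined.
        return float("nan")
    return cov / (sp * sg)
```

Because only ranks matter, a model can improve Spearman's correlation without any change in the scale of its loss. This is why the two signals are worth tracking separately.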

Thank you!

NSC508 commented 2 months ago

@SeanLee97 Thank you for the quick response! I see, understood. A follow-up question in the meantime: is there a way for me to manually calculate the loss for a dataset in the Dataset.A format? My idea was to use AngleDataCollator to tokenize the dataset and then pass it through the model backbone to get the loss. However, this does not work:

collator = AngleDataCollator(angle.tokenizer, padding=True, return_tensors='pt', max_length=angle.max_length, filter_duplicate=True, coword_random_mask_rate=0)

outputs = angle.backbone(collator(eval_ds)['input_ids'].to('cuda'), labels=torch.tensor(eval_ds['label']).to('cuda'))
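One possible reason this fails is that angle.backbone may be a plain Hugging Face encoder (an assumption here). Such an encoder returns hidden states and does not compute AnglE's loss from labels, so the loss has to be computed from sentence embeddings instead.

Below is a minimal, library-free sketch of that idea. It assumes you already have some way to embed texts; the encode callable is hypothetical. It implements only the CoSENT-style cosine ranking term described in the AnglE paper, with an assumed scale of 20. The full AnglE objective also includes in-batch negative and angle terms, so these numbers will not match the loss angle_emb reports.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def cosent_loss(cos_sims, labels, scale=20.0):
    """CoSENT-style ranking loss: log(1 + sum over pairs with labels[i] > labels[j]
    of exp(scale * (cos_sims[j] - cos_sims[i]))). Scale 20 is an assumed default."""
    terms = [scale * (cos_sims[j] - cos_sims[i])
             for i in range(len(labels)) for j in range(len(labels))
             if labels[i] > labels[j]]
    # log(1 + sum(exp(t))) == logsumexp([0] + terms), computed stably.
    all_terms = [0.0] + terms
    m = max(all_terms)
    return m + math.log(sum(math.exp(t - m) for t in all_terms))


def validation_loss(batches, encode, scale=20.0):
    """Average cosent_loss over batches of Dataset.A-style rows (text1, text2, label).
    `encode` is a hypothetical callable: list of texts -> list of vectors."""
    total, count = 0.0, 0
    for batch in batches:
        e1 = encode([row["text1"] for row in batch])
        e2 = encode([row["text2"] for row in batch])
        sims = [cosine(a, b) for a, b in zip(e1, e2)]
        total += cosent_loss(sims, [row["label"] for row in batch], scale)
        count += 1
    return total / count
```

In practice, encode would wrap the fine-tuned model's embedding call (run without gradients), and batches would be chunks of the validation rows. Tracking this value per epoch next to the training loss gives the overfitting signal asked about in the original question.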

SeanLee97 commented 1 month ago

Hi @NSC508, evaluation loss is now supported in the latest angle-emb. You can upgrade via pip install -U angle-emb to use this feature. To output the evaluation loss, specify --valid_name_or_path; its dataset format should be consistent with that of --train_name_or_path. The option previously used to pass the callback's validation set has been renamed to --valid_name_or_path_for_callback.
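Since --valid_name_or_path must use the same format as --train_name_or_path, it can help to check that both files agree before launching a run. Below is a minimal sketch. It assumes both are local JSON Lines files whose records use the Dataset.A fields text1, text2 and label. These field names are an assumption based on AnglE's pair format, and the helpers read_jsonl and check_same_format are hypothetical.

```python
import json

# Assumed Dataset.A fields; adjust if your data uses another AnglE format.
FORMAT_A_KEYS = {"text1", "text2", "label"}


def read_jsonl(lines):
    """Parse an iterable of JSON Lines strings (e.g. an open file), skipping blank lines."""
    return [json.loads(line) for line in lines if line.strip()]


def check_same_format(train_rows, valid_rows, expected=FORMAT_A_KEYS):
    """Raise ValueError naming the first row that lacks an expected key; return True otherwise."""
    for name, rows in (("train", train_rows), ("valid", valid_rows)):
        for idx, row in enumerate(rows):
            missing = expected - set(row)
            if missing:
                raise ValueError(f"{name} row {idx} is missing {sorted(missing)}")
    return True
```

Typical use would be opening each file with open(path) and passing the file object to read_jsonl before calling check_same_format.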

Related PR: https://github.com/SeanLee97/AnglE/pull/100