KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/
MIT License

Combining two loss functions #648

Closed. HarryP23 closed this issue 1 year ago.

HarryP23 commented 1 year ago

Hello! I'm trying to combine two loss functions, contrastive loss and ArcFace loss, during training. The combined loss should be the sum of the two individual loss values. Should I create a new class named CombinedLoss, or modify the trainer to add the two losses?

KevinMusgrave commented 1 year ago

I think MultipleLosses would work: https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#multiplelosses

Example:

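```python
from pytorch_metric_learning.losses import ContrastiveLoss, ArcFaceLoss, MultipleLosses

loss1 = ContrastiveLoss()
# ArcFaceLoss has its own learnable weight matrix, so pass loss2.parameters() to an optimizer
loss2 = ArcFaceLoss(embedding_size=32, num_classes=10)
# MultipleLosses returns the sum of the wrapped losses (optionally weighted via `weights`)
loss_fn = MultipleLosses([loss1, loss2])
# Call it like any single loss: loss = loss_fn(embeddings, labels)
```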
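To make the "add the two results" idea concrete, here is a minimal plain-Python sketch of what a hand-written combined loss would compute. `CombinedLoss`, `toy_loss_a`, and `toy_loss_b` are hypothetical names used only for illustration; this is not the library's `MultipleLosses` implementation, though `MultipleLosses` also accepts an optional `weights` argument for the same kind of weighting.

```python
from typing import Callable, Optional, Sequence


class CombinedLoss:
    """Hypothetical helper: weighted sum of several loss callables.

    Illustrates the idea only; it is not pytorch-metric-learning code.
    """

    def __init__(self, losses: Sequence[Callable[..., float]],
                 weights: Optional[Sequence[float]] = None) -> None:
        if weights is None:
            weights = [1.0] * len(losses)  # default: plain, unweighted sum
        if len(weights) != len(losses):
            raise ValueError("need exactly one weight per loss")
        self.losses = list(losses)
        self.weights = list(weights)

    def __call__(self, *args, **kwargs):
        # Every wrapped loss receives the same inputs (e.g. embeddings, labels)
        return sum(w * loss(*args, **kwargs)
                   for loss, w in zip(self.losses, self.weights))


# Toy stand-ins for two real losses (hypothetical, for demonstration only)
def toy_loss_a(x):
    return 2.0 * x


def toy_loss_b(x):
    return x + 1.0


combined = CombinedLoss([toy_loss_a, toy_loss_b])
print(combined(3.0))  # 6.0 + 4.0 -> 10.0

weighted = CombinedLoss([toy_loss_a, toy_loss_b], weights=[1.0, 0.5])
print(weighted(3.0))  # 6.0 + 0.5 * 4.0 -> 8.0
```

Because `sum` starts from `0` and only uses `+` and `*`, the same pattern would also work if each loss returned a PyTorch tensor, so gradients would flow through the total. In practice, `MultipleLosses` already covers this, so a custom class or trainer change should not be needed.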