KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/

Argument names are different in forward when using the DDP wrapper #701

Closed. elisim closed this issue 2 months ago.

elisim commented 2 months ago

Hi, thank you for the great repository. I find it very useful.

When calling the loss function with keyword arguments, wrapping the loss with the DDP wrapper (DistributedLossWrapper) fails with the following error:

TypeError: DistributedLossWrapper.forward() got an unexpected keyword argument 'embeddings'

Minimal example to reproduce:

import torch
from pytorch_metric_learning.losses import CosFaceLoss

embeddings = torch.randn(3, 16)
labels = torch.LongTensor([0, 1, 2])
loss_fn = CosFaceLoss(num_classes=3, embedding_size=16)

loss = loss_fn(embeddings=embeddings, labels=labels)  # works 
print(loss)

And with the DistributedLossWrapper:

import torch
from pytorch_metric_learning.losses import CosFaceLoss
from pytorch_metric_learning.utils import distributed as pml_dist

embeddings = torch.randn(3, 16)
labels = torch.LongTensor([0, 1, 2])
loss_fn = CosFaceLoss(num_classes=3, embedding_size=16)
loss_fn = pml_dist.DistributedLossWrapper(loss_fn)

loss = loss_fn(embeddings=embeddings, labels=labels) # error
print(loss)

This behavior is inconvenient: the wrapper does not accept the same keyword arguments as the loss it wraps, so it is not a drop-in replacement. A possible positional-call workaround is sketched below.
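Until a fix lands, one plausible workaround (an assumption based on the error message, not something confirmed in this thread) is to drop the keyword names and pass the tensors positionally, since the TypeError only complains about the unexpected keyword 'embeddings':

import torch
from pytorch_metric_learning.losses import CosFaceLoss
from pytorch_metric_learning.utils import distributed as pml_dist

# Assumes a torch.distributed process group is already initialized
# (e.g., the script was launched with torchrun); the wrapper is meant
# to be used inside a distributed training run.
embeddings = torch.randn(3, 16)
labels = torch.LongTensor([0, 1, 2])
loss_fn = pml_dist.DistributedLossWrapper(CosFaceLoss(num_classes=3, embedding_size=16))

loss = loss_fn(embeddings, labels)  # positional call sidesteps the keyword mismatch
print(loss)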

KevinMusgrave commented 2 months ago

Fixed in version 2.6.0:

pip install pytorch-metric-learning==2.6.0
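After upgrading, the keyword-argument call from the minimal example above should go through (a quick sanity check, again assuming an initialized distributed process group):

import torch
from pytorch_metric_learning.losses import CosFaceLoss
from pytorch_metric_learning.utils import distributed as pml_dist

embeddings = torch.randn(3, 16)
labels = torch.LongTensor([0, 1, 2])
loss_fn = pml_dist.DistributedLossWrapper(CosFaceLoss(num_classes=3, embedding_size=16))

loss = loss_fn(embeddings=embeddings, labels=labels)  # no TypeError on >= 2.6.0
print(loss)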