KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/
MIT License

How to transfer NTXent loss for segmentation task? #229

Closed Kyfafyd closed 1 year ago

KevinMusgrave commented 3 years ago

I'm guessing you have an embedding and a label for each pixel in the image. You can pass all of these embeddings to NTXentLoss:

```python
from pytorch_metric_learning.losses import NTXentLoss

loss_fn = NTXentLoss()
pixel_loss = loss_fn(embeddings, labels)
```

However, since there are so many pixels in an image, you will probably run out of memory. So you can try randomly sampling a reasonable number of triplets and passing those into the loss function:

```python
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu

indices_tuple = lmu.get_random_triplet_indices(labels, t_per_anchor=1)
pixel_loss = loss_fn(embeddings, labels, indices_tuple)
```

t_per_anchor means triplets per anchor. So the larger you make that, the higher the memory consumption will be.

Kyfafyd commented 3 years ago

Thanks for your comment! Yes, I have an embedding and a label for each pixel in the image. The total number of classes is 5 and the batch size is 4; the embedding shape is torch.Size([4, 32, 384, 384]) and the label shape is torch.Size([4, 384, 384]). But I got an error like this:

```
Traceback (most recent call last):
  File "train.py", line 198, in <module>
    main()
  File "train.py", line 174, in main
    train_loss, train_dices = train(model, train_loader, optimizer, LOSS_FUNC, lr_sheduler, device)
  File "train.py", line 51, in train
    contrastive_loss = contrastive_loss_func(contrastive_feature, label)
  File "/research/dept8/qdou/zwang/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/research/dept8/qdou/zwang/anaconda3/lib/python3.8/site-packages/pytorch_metric_learning/losses/base_metric_loss_function.py", line 32, in forward
    loss_dict = self.compute_loss(embeddings, labels, indices_tuple)
  File "/research/dept8/qdou/zwang/anaconda3/lib/python3.8/site-packages/pytorch_metric_learning/losses/generic_pair_loss.py", line 14, in compute_loss
    indices_tuple = lmu.convert_to_pairs(indices_tuple, labels)
  File "/research/dept8/qdou/zwang/anaconda3/lib/python3.8/site-packages/pytorch_metric_learning/utils/loss_and_miner_utils.py", line 57, in convert_to_pairs
    return get_all_pairs_indices(labels)
  File "/research/dept8/qdou/zwang/anaconda3/lib/python3.8/site-packages/pytorch_metric_learning/utils/loss_and_miner_utils.py", line 41, in get_all_pairs_indices
    matches.fill_diagonal_(0)
RuntimeError: all dimensions of input must be of equal length
```

I am wondering how to solve this?

KevinMusgrave commented 3 years ago

You need to reshape the embeddings to have shape (N, D), and labels to have shape (N,).

Something like this might work, though I haven't confirmed that the reshaping of embeddings matches the reshaping of labels.

```python
embeddings = embeddings.permute(0, 2, 3, 1)        # (B, C, H, W) -> (B, H, W, C)
embeddings = embeddings.contiguous().view(-1, 32)  # (N, D) where N = B*H*W
labels = labels.view(-1)                           # (N,)
```

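One way to confirm that the two reshapes stay aligned is a small sanity check (not from the thread; tensor sizes are made up for illustration). Because both flattenings iterate in (B, H, W) order, row `b*H*W + h*W + w` of the flattened embeddings should correspond to pixel `(b, h, w)`:

```python
import torch

B, C, H, W = 2, 32, 4, 4  # tiny illustrative sizes
embeddings = torch.randn(B, C, H, W)
labels = torch.randint(0, 5, (B, H, W))

# Channels-last, then flatten the spatial and batch dims together
flat_emb = embeddings.permute(0, 2, 3, 1).contiguous().view(-1, C)  # (N, D)
flat_lab = labels.view(-1)                                          # (N,)

# Pick an arbitrary pixel and check the flattened row matches it
b, h, w = 1, 2, 3
idx = b * H * W + h * W + w
assert torch.equal(flat_emb[idx], embeddings[b, :, h, w])
assert flat_lab[idx] == labels[b, h, w]
```

Since both tensors flatten in the same (B, H, W) order, each embedding row keeps its matching label.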
Kyfafyd commented 3 years ago

Thanks very much for the help! As you said, I ran out of memory. But I hit this problem when trying to randomly sample a reasonable number of triplets:

```
Traceback (most recent call last):
  File "train.py", line 212, in <module>
    main()
  File "train.py", line 186, in main
    train_loss, train_dices = train(model, train_loader, optimizer, LOSS_FUNC, lr_sheduler, device)
  File "train.py", line 51, in train
    indices_tuple = lmu.get_random_triplet_indices(label, t_per_anchor=1)
  File "/research/dept8/qdou/zwang/anaconda3/lib/python3.8/site-packages/pytorch_metric_learning/utils/loss_and_miner_utils.py", line 119, in get_random_triplet_indices
    pinds = pinds[~torch.eye(n_a).bool()].view((n_a, n_a - 1))
RuntimeError: [enforce fail at CPUAllocator.cpp:64] . DefaultCPUAllocator: can't allocate memory: you tried to allocate 1322049238416 bytes. Error code 12 (Cannot allocate memory)
```

my code is:

```python
contrastive_loss_func = NTXentLoss(temperature=0.1)
contrastive_feature = contrastive_feature.permute(0, 2, 3, 1)
contrastive_feature = contrastive_feature.contiguous().view(-1, 32)
label = label.view(-1)
indices_tuple = lmu.get_random_triplet_indices(label, t_per_anchor=1)
contrastive_loss = contrastive_loss_func(contrastive_feature, label, indices_tuple)
```

KevinMusgrave commented 3 years ago

Hmm I see: because the batch size is huge (4 × 384 × 384 = 589,824 pixels), that function isn't able to allocate the matrices it needs.

I'll have to think about how to solve this large-batch problem. In the meantime, I think the only workaround would be to randomly sample pixels, to reduce the batch size.
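A minimal sketch of that workaround (not from the thread; the sample budget and variable names are assumptions) is to draw a random subset of pixel indices after flattening, before building triplets or computing the loss:

```python
import torch

# Flattened per-pixel embeddings/labels, as in the earlier reshape step
# (4 * 384 * 384 pixels here, matching the shapes reported above)
flat_emb = torch.randn(4 * 384 * 384, 32)
flat_lab = torch.randint(0, 5, (4 * 384 * 384,))

num_samples = 4096  # assumed budget; tune based on available memory
perm = torch.randperm(flat_lab.numel())[:num_samples]
sub_emb = flat_emb[perm]  # (num_samples, D)
sub_lab = flat_lab[perm]  # (num_samples,)

# sub_emb and sub_lab can now be passed to NTXentLoss /
# get_random_triplet_indices in place of the full pixel set
```

Sampling a few thousand pixels per batch keeps the pairwise/triplet index matrices small while still covering all classes with high probability.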

Kyfafyd commented 3 years ago

Thanks a lot! I have randomly sampled pixels for training, but it doesn't seem to work. Looking forward to the repo update!

Kyfafyd commented 3 years ago

Hi, has there been any update on this issue recently?

KevinMusgrave commented 3 years ago

Sorry, I haven't gotten around to this yet.