KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/
MIT License

Stratified Sampler with NTXent + Cross-Batch -- Cuda Error #604

Closed: sapan closed this issue 1 year ago.

sapan commented 1 year ago

Thanks, Kevin, for this amazing library.

The data I am using for contrastive learning is imbalanced: there are 28 classes, and one class accounts for about 35% of the data. I have written a custom stratified sampler that builds each batch so that its class mix follows the overall data distribution. The other settings I am trying are: batch size 32 or 64; loss: NTXentLoss; CrossBatchMemory with a memory size of 1024 or 2048.
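The question does not include the sampler's code. Purely as an illustration, here is one hypothetical way to build stratified batches in plain Python. The function name `stratified_batches` and its arguments are invented for this sketch, and the toy label list is made up to mirror the 28-class, 35% setting. Each class's shuffled indices are spread evenly over [0, 1), all indices are sorted by that position, and the sorted order is cut into consecutive batches. As a result, every batch gets roughly its proportional share of each class, and rare classes still show up in a proportional number of batches.

```python
import random
from collections import defaultdict


def stratified_batches(labels, batch_size, seed=0):
    """Hypothetical stratified batch sampler (a sketch, not the poster's code).

    Returns a list of batches (lists of dataset indices) in which each class
    appears roughly in proportion to its frequency in `labels`.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    keyed = []
    for indices in by_class.values():
        rng.shuffle(indices)
        offset = rng.random()  # staggers classes so they do not all start at 0
        for rank, idx in enumerate(indices):
            # Evenly spaced position in [0, 1) for each member of the class.
            keyed.append(((rank + offset) / len(indices), idx))
    keyed.sort()
    order = [idx for _, idx in keyed]

    # Cut into consecutive batches and drop the last partial one (drop_last).
    return [
        order[start:start + batch_size]
        for start in range(0, len(order) - batch_size + 1, batch_size)
    ]


# Toy data: 28 classes, label 0 holds 35% of the examples.
labels = [0] * 350 + [1 + i % 27 for i in range(650)]
batches = stratified_batches(labels, batch_size=32)
print(len(batches), [sum(labels[i] == 0 for i in b) for b in batches[:5]])
```

Calling it with a new `seed` each epoch gives a fresh ordering. The returned list of index lists can be passed to `torch.utils.data.DataLoader` through its `batch_sampler` argument.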

I get a CUDA out-of-memory (OOM) error even with small batch sizes (on a V100 with 32 GB of memory). It occurs during the loss computation, at `denominator = torch.sum(torch.exp(neg_pairs - max_val), dim=1) + numerator` (line 34 of ntxent_loss.py).

I think this is because 35% of the cross-batch memory is occupied by a single class, so |neg_pairs| is large for anchors from the other classes. (With MPerClassSampler, by contrast, each anchor has roughly |neg_pairs| = mem_size * (1 - 1/num_of_classes) negatives in the memory.)
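A rough count helps show where the memory goes. The sketch below is a back-of-the-envelope calculation, not library code. It makes several assumptions: the memory is full and its class mix mirrors the data, the batch's own entries in the memory are ignored, and, as the `dim=1` sum in the quoted line suggests, NTXent at that point holds one row of negatives for every positive pair, giving a |pos_pairs| x |neg_pairs| matrix. The helpers `class_counts` and `pair_counts` are hypothetical.

```python
def class_counts(total, n_classes=28, major_frac=None):
    """Split `total` examples over `n_classes` labels (0 .. n_classes-1).

    If `major_frac` is given, label 0 gets that fraction and the other labels
    share the rest as evenly as possible; otherwise all labels share evenly.
    """
    counts = {}
    first = 0
    if major_frac is not None:
        counts[0] = round(total * major_frac)
        total -= counts[0]
        first = 1
    share, extra = divmod(total, n_classes - first)
    for i, label in enumerate(range(first, n_classes)):
        counts[label] = share + (1 if i < extra else 0)
    return counts


def pair_counts(batch_counts, memory_counts):
    """Number of (anchor, positive) and (anchor, negative) pairs between a
    batch and the memory bank; self-comparisons are ignored for simplicity."""
    memory_total = sum(memory_counts.values())
    pos = neg = 0
    for label, n_anchors in batch_counts.items():
        same = memory_counts.get(label, 0)
        pos += n_anchors * same
        neg += n_anchors * (memory_total - same)
    return pos, neg


for name, frac in [("balanced", None), ("skewed", 0.35)]:
    batch = class_counts(64, major_frac=frac)      # batch size 64
    memory = class_counts(2048, major_frac=frac)   # memory size 2048
    pos, neg = pair_counts(batch, memory)
    cells = pos * neg  # entries in a |pos_pairs| x |neg_pairs| matrix
    print(f"{name}: pos={pos} neg={neg} pos*neg={cells:.1e} "
          f"(~{cells * 4 / 2**30:.1f} GiB as float32)")
```

Under these assumptions, the skewed setting gives about 17.8k positive pairs and 113k negative pairs, or about 2.0e9 matrix entries (roughly 7.5 GiB per float32 copy). A balanced memory gives about 4.7k and 126k, or about 5.9e8 entries. The per-anchor negative count barely changes. The main growth comes from the many positive pairs contributed by majority-class anchors, and that count multiplies |neg_pairs|. Capping |neg_pairs| at k shrinks the matrix to |pos_pairs| x k.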

Is there a way to handle this situation? For example, we could cap the number of negative pairs at some size k, so that whenever |neg_pairs| > k, neg_pairs is trimmed to k entries. Do you have any suggestions?

KevinMusgrave commented 1 year ago

This might work:

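```python
from pytorch_metric_learning.losses import CrossBatchMemory, NTXentLoss
from pytorch_metric_learning.miners import BaseMiner
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu


class LimitedNumberOfPairs(BaseMiner):
    """Returns all positive pairs, but at most `max_neg` negative pairs."""

    def __init__(self, max_neg, **kwargs):
        super().__init__(**kwargs)
        self.max_neg = max_neg

    def mine(self, embeddings, labels, ref_emb, ref_labels):
        # All (anchor, positive) and (anchor, negative) pairs between the
        # batch and the reference set (the memory bank, under CrossBatchMemory).
        a1, p, a2, n = lmu.get_all_pairs_indices(labels, ref_labels)
        # Keep only the first max_neg negative pairs.
        a2, n = a2[:self.max_neg], n[:self.max_neg]
        return a1, p, a2, n


k = 10000  # cap on negative pairs (illustrative value; tune to your GPU)
loss_fn = NTXentLoss()
embedding_size = 512
miner = LimitedNumberOfPairs(k)

# Pass the miner by keyword: the third positional argument is memory_size.
loss_fn = CrossBatchMemory(loss_fn, embedding_size, memory_size=2048, miner=miner)
```
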
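One thing to keep in mind about the slice in `mine`: `get_all_pairs_indices` builds its pairs with `torch.where`, which, as far as I know, returns them in anchor-major order. That means `a2[:self.max_neg]` keeps the negatives of the first few anchors in the batch and drops all negatives of the rest. Positive pairs whose anchor is left with no negatives have nothing to contrast against, so they add little or nothing to the loss. A hypothetical alternative, which is not what the reply above does, is to sample at most a fixed number of negatives per anchor at random. `cap_negatives_per_anchor` is an invented helper, sketched here with plain PyTorch ops:

```python
import torch


def cap_negatives_per_anchor(a2, n, max_per_anchor, generator=None):
    """Keep at most `max_per_anchor` randomly chosen negatives per anchor.

    a2, n: 1-D int64 tensors of anchor and negative indices (same length).
    Hypothetical helper, not part of pytorch-metric-learning.
    """
    if a2.numel() == 0:
        return a2, n
    # Shuffle the pairs, then stable-sort by anchor: each anchor's negatives
    # end up contiguous and in random order.
    perm = torch.randperm(a2.numel(), generator=generator).to(a2.device)
    a2, n = a2[perm], n[perm]
    order = torch.sort(a2, stable=True).indices
    a2, n = a2[order], n[order]
    # Rank of each pair within its anchor's group: 0, 1, 2, ...
    counts = torch.bincount(a2)
    starts = torch.cumsum(counts, dim=0) - counts
    rank = torch.arange(a2.numel(), device=a2.device) - starts[a2]
    keep = rank < max_per_anchor
    return a2[keep], n[keep]


a2 = torch.tensor([0, 0, 0, 0, 1, 1, 2])
n = torch.tensor([3, 4, 5, 6, 3, 4, 5])
capped_a2, capped_n = cap_negatives_per_anchor(a2, n, max_per_anchor=2)
print(torch.bincount(capped_a2).tolist())  # at most 2 negatives per anchor
```

Inside `LimitedNumberOfPairs.mine`, the slice would become `a2, n = cap_negatives_per_anchor(a2, n, self.max_neg)`. Note that `max_neg` then means negatives per anchor, so the total number of negative pairs is at most batch_size * max_neg.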
sapan commented 1 year ago

Thanks a lot, Kevin. It solved my problem. I just couldn't reply on the issue itself.

On Tue, Apr 11, 2023, 2:20 AM Kevin Musgrave wrote:

Closed #604 (https://github.com/KevinMusgrave/pytorch-metric-learning/issues/604) as completed.
