Yunfan-Li / Contrastive-Clustering

Code for the paper "Contrastive Clustering" (AAAI 2021)
MIT License
297 stars 91 forks

Instance Loss output different from NT Xent loss #15

Closed: sramakrishnan247 closed this issue 3 years ago

sramakrishnan247 commented 3 years ago

Hey,

I found a difference between the output of the implemented InstanceLoss and the NT-Xent loss taken from SimCLR (https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/nt_xent.py).

Although the two functions look very similar, their outputs seem to differ. Could you please look into it and share your insights?

import torch
import torch.nn as nn
import math

class InstanceLoss(nn.Module):
    def __init__(self, batch_size, temperature, device):
        super(InstanceLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device

        self.mask = self.mask_correlated_samples(batch_size)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

    def mask_correlated_samples(self, batch_size):
        N = 2 * batch_size
        mask = torch.ones((N, N))
        mask = mask.fill_diagonal_(0)
        for i in range(batch_size):
            mask[i, batch_size + i] = 0
            mask[batch_size + i, i] = 0
        mask = mask.bool()
        return mask

    def forward(self, z_i, z_j):
        N = 2 * self.batch_size
        z = torch.cat((z_i, z_j), dim=0)

        sim = torch.matmul(z, z.T) / self.temperature
        sim_i_j = torch.diag(sim, self.batch_size)
        sim_j_i = torch.diag(sim, -self.batch_size)

        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
        negative_samples = sim[self.mask].reshape(N, -1)

        labels = torch.zeros(N).to(positive_samples.device).long()
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        loss = self.criterion(logits, labels)
        loss /= N

        return loss

class NT_Xent(nn.Module):
    """
    More than inspired from https://github.com/Spijkervet/SimCLR/blob/master/modules/nt_xent.py

    Notes
    =====

    Using this pytorch implementation, you don't actually need to l2-norm the inputs, the results will be
    identical, as shown if you run this file.
    """

    def __init__(self, batch_size, temperature, device):
        super(NT_Xent, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.mask = self.get_correlated_samples_mask()
        self.device = device

        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        self.similarity_f = nn.CosineSimilarity(dim=2)

    def forward(self, z_i, z_j):
        """
        We do not sample negative examples explicitly.
        Instead, given a positive pair, similar to (Chen et al., 2017), we treat the other 2(N − 1) augmented examples within a minibatch as negative examples.
        """

        p1 = torch.cat((z_i, z_j), dim=0)
        sim = self.similarity_f(p1.unsqueeze(1), p1.unsqueeze(0)) / self.temperature

        sim_i_j = torch.diag(sim, self.batch_size)
        sim_j_i = torch.diag(sim, -self.batch_size)

        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(self.batch_size * 2, 1)
        negative_samples = sim[self.mask].reshape(self.batch_size * 2, -1)

        labels = torch.zeros(self.batch_size * 2).to(self.device).long()
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        loss = self.criterion(logits, labels)
        loss /= 2 * self.batch_size
        return loss

    def get_correlated_samples_mask(self):
        mask = torch.ones((self.batch_size * 2, self.batch_size * 2), dtype=bool)
        mask = mask.fill_diagonal_(0)
        for i in range(self.batch_size):
            mask[i, self.batch_size + i] = 0
            mask[self.batch_size + i, i] = 0
        return mask

a, b = torch.rand(8, 12), torch.rand(8, 12)
a_norm, b_norm = torch.nn.functional.normalize(a), torch.nn.functional.normalize(b)
cosine_sim = torch.nn.CosineSimilarity()
instance_loss = InstanceLoss(8, 0.5, "cpu")
ntxent_loss = NT_Xent(8, 0.5, "cpu")
print('Cosine')
print(cosine_sim(a, b))
print(cosine_sim(a_norm, b_norm))
print('NT Xent')
print(ntxent_loss(a, b))
print(ntxent_loss(a_norm, b_norm))
print('Instance')
print(instance_loss(a, b))
print(instance_loss(a_norm, b_norm))

Output:

Cosine
tensor([0.6606, 0.7330, 0.7845, 0.8602, 0.6992, 0.8224, 0.7167, 0.7500])
tensor([0.6606, 0.7330, 0.7845, 0.8602, 0.6992, 0.8224, 0.7167, 0.7500])
NT Xent
tensor(2.7081)
tensor(2.7081)
Instance
tensor(3.1286)
tensor(2.7081)

As you can see, InstanceLoss gives different results for (a, b) and (a_norm, b_norm), whereas the other two do not.

Colab notebook: https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/nt_xent.py

Yunfan-Li commented 3 years ago

Hi,

It is because the normalization should be conducted on each row rather than on the entire feature matrix. Our instance loss assumes that each instance's feature vector has been normalized to unit length, so that the inner product can be used directly as the cosine similarity. This assumption can be violated when the input is normalized across the entire matrix. The NT-Xent loss, on the other hand, uses PyTorch's CosineSimilarity, which normalizes the features of each instance internally.

In short, I think the two losses will give the same output if you change a_norm and b_norm to torch.nn.functional.normalize(a, dim=1) and torch.nn.functional.normalize(b, dim=1).
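Two facts behind this explanation can be checked with a short sketch (the random inputs and the 8 x 12 shape are illustrative assumptions). First, once every row has unit L2 norm, the inner-product matrix z @ z.T that InstanceLoss uses coincides with the pairwise cosine-similarity matrix that NT_Xent builds with nn.CosineSimilarity. Second, torch.nn.functional.normalize already defaults to p=2, dim=1, so for a 2-D (batch, features) tensor the default call and the explicit dim=1 call produce the same row-normalized tensor.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
a = torch.rand(8, 12)  # (batch, features); shape chosen for illustration
b = torch.rand(8, 12)

# F.normalize defaults to p=2, dim=1: each row is scaled to unit L2 norm.
assert torch.allclose(F.normalize(a), F.normalize(a, dim=1))
assert torch.allclose(F.normalize(a).norm(dim=1), torch.ones(8))

# Inner products of unit-norm rows, as InstanceLoss computes them (before / temperature).
z = torch.cat((F.normalize(a, dim=1), F.normalize(b, dim=1)), dim=0)
inner = z @ z.T

# Pairwise cosine similarity of the raw rows, as NT_Xent computes it.
p = torch.cat((a, b), dim=0)
cosine = nn.CosineSimilarity(dim=2)(p.unsqueeze(1), p.unsqueeze(0))

print(torch.allclose(inner, cosine, atol=1e-5))  # True

# Without normalization the inner product is not a cosine similarity.
raw_inner = p @ p.T
print(torch.allclose(raw_inner, cosine, atol=1e-5))  # False
```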

sramakrishnan247 commented 3 years ago

Thanks for the clarification.

sramakrishnan247 commented 3 years ago

It still doesn't seem to give the same output.

Code:

a, b = torch.rand(8, 12), torch.rand(8, 12)
a_norm, b_norm = torch.nn.functional.normalize(a, dim=1), torch.nn.functional.normalize(b, dim=1)

cosine_sim = torch.nn.CosineSimilarity()
instance_loss = InstanceLoss(8, 0.5, "cpu")
ntxent_loss = NT_Xent(8, 0.5, "cpu")
print('Cosine')
print(cosine_sim(a, b))
print(cosine_sim(a_norm, b_norm))
print('NT Xent')
print(ntxent_loss(a, b))
print(ntxent_loss(a_norm, b_norm))
print('Instance')
print(instance_loss(a, b))
print(instance_loss(a_norm, b_norm))

Output:

Cosine
tensor([0.6642, 0.7078, 0.7042, 0.8050, 0.7898, 0.8841, 0.7750, 0.7800])
tensor([0.6642, 0.7078, 0.7042, 0.8050, 0.7898, 0.8841, 0.7750, 0.7800])
NT Xent
tensor(2.7333)
tensor(2.7333)
Instance
tensor(3.1888)
tensor(2.7333)

Also, what exactly is mask_correlated_samples trying to achieve?

Yunfan-Li commented 3 years ago

Sorry, I didn't make it clear. instance_loss(a, b) and instance_loss(a_norm, b_norm) are not guaranteed to give the same output unless the input is already normalized per instance (i.e., along dim=1). Thus, the input must be normalized before it is fed into InstanceLoss; as can be seen, doing so gives the same output as the NT-Xent loss. If you want instance_loss(a, b) and instance_loss(a_norm, b_norm) to be equal, you can add the normalization at the beginning of the forward function (e.g., z_i = torch.nn.functional.normalize(z_i, dim=1) and z_j = torch.nn.functional.normalize(z_j, dim=1)).
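The suggestion in this comment can be sketched as a small variant of InstanceLoss that normalizes inside forward. The class name NormalizedInstanceLoss is hypothetical (it is not part of the Contrastive-Clustering repository), and the mask is built with vectorized indexing instead of the Python loop, which yields the same boolean matrix. Because each row is unit-normalized first, rescaling the inputs row-wise no longer changes the loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NormalizedInstanceLoss(nn.Module):
    """Hypothetical InstanceLoss variant that L2-normalizes each row of its inputs first."""

    def __init__(self, batch_size, temperature):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        # Same mask as mask_correlated_samples: drop self-similarities and positive pairs.
        n = 2 * batch_size
        mask = torch.ones((n, n), dtype=torch.bool)
        mask.fill_diagonal_(False)
        idx = torch.arange(batch_size)
        mask[idx, batch_size + idx] = False
        mask[batch_size + idx, idx] = False
        self.register_buffer("mask", mask)

    def forward(self, z_i, z_j):
        # The added step: unit-normalize each instance so z @ z.T is a cosine similarity.
        z_i = F.normalize(z_i, dim=1)
        z_j = F.normalize(z_j, dim=1)

        n = 2 * self.batch_size
        z = torch.cat((z_i, z_j), dim=0)
        sim = z @ z.T / self.temperature
        positives = torch.cat((torch.diag(sim, self.batch_size),
                               torch.diag(sim, -self.batch_size)), dim=0).reshape(n, 1)
        negatives = sim[self.mask].reshape(n, -1)

        logits = torch.cat((positives, negatives), dim=1)
        labels = torch.zeros(n, dtype=torch.long, device=logits.device)  # positive is column 0
        return self.criterion(logits, labels) / n


torch.manual_seed(0)
a, b = torch.rand(8, 12), torch.rand(8, 12)
loss_fn = NormalizedInstanceLoss(batch_size=8, temperature=0.5)

raw = loss_fn(a, b)
pre = loss_fn(F.normalize(a, dim=1), F.normalize(b, dim=1))
scaled = loss_fn(3.0 * a, 0.1 * b)  # row-wise rescaling no longer changes the result
print(torch.allclose(raw, pre), torch.allclose(raw, scaled))  # True True
```

The design choice here is simply to make the unit-norm assumption explicit in the module rather than relying on callers; the cost is one extra normalization per forward pass.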

Yunfan-Li commented 3 years ago

mask_correlated_samples masks out the positive pairs (i.e., samples that belong to the same instance) together with each sample's similarity to itself, so that the cross-entropy loss can be conveniently used to compute the InfoNCE loss.
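To illustrate, the sketch below rebuilds the same mask for a toy batch_size of 2 (so N = 4): each row keeps exactly N - 2 entries, because the diagonal and the positive partner are removed, and InstanceLoss places the positive in column 0 instead so that a cross-entropy target of 0 gives the InfoNCE loss. As a cross-check, it compares against an alternative formulation (an assumption for illustration, not code from the repository) that sets the diagonal to -inf and uses the partner's index as the cross-entropy target.

```python
import torch
import torch.nn.functional as F


def mask_correlated_samples(batch_size):
    # Same logic as InstanceLoss.mask_correlated_samples in the issue.
    n = 2 * batch_size
    mask = torch.ones((n, n))
    mask = mask.fill_diagonal_(0)
    for i in range(batch_size):
        mask[i, batch_size + i] = 0
        mask[batch_size + i, i] = 0
    return mask.bool()


mask = mask_correlated_samples(2)
print(mask.int().tolist())  # [[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]]
assert mask.sum(dim=1).tolist() == [2, 2, 2, 2]  # N - 2 negatives per row

# InfoNCE via the mask, following InstanceLoss (mean == sum / N).
torch.manual_seed(0)
batch_size, temperature = 4, 0.5
n = 2 * batch_size
z = F.normalize(torch.randn(n, 16), dim=1)
sim = z @ z.T / temperature

positives = torch.cat((torch.diag(sim, batch_size), torch.diag(sim, -batch_size))).reshape(n, 1)
negatives = sim[mask_correlated_samples(batch_size)].reshape(n, -1)
logits = torch.cat((positives, negatives), dim=1)
loss_mask = F.cross_entropy(logits, torch.zeros(n, dtype=torch.long))

# Alternative: exclude self-similarity with -inf; the target is the partner's index.
sim_alt = sim.masked_fill(torch.eye(n, dtype=torch.bool), float("-inf"))
targets = torch.cat((torch.arange(batch_size, n), torch.arange(0, batch_size)))
loss_alt = F.cross_entropy(sim_alt, targets)

print(torch.allclose(loss_mask, loss_alt))  # True
```

Both versions normalize each positive over the same set of logits (every other sample in the batch), which is why they agree; the masked version just moves the positive to a fixed column.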

sramakrishnan247 commented 3 years ago

Thanks for the further clarification.