imbue-ai / self_supervised

A Pytorch-Lightning implementation of self-supervised algorithms
MIT License

BYOL Loss Function is Wrong #7

Closed JosephDiPalma closed 3 years ago

JosephDiPalma commented 3 years ago

The loss function for the BYOL model seems to be wrong. Below is a code snippet to demonstrate this issue.

import torch
import torch.nn.functional as F

# Settings for BYOL.
use_negative_examples_from_batch = False
use_negative_examples_from_queue = False
loss_type = "ip"
batch_size = 128
output_dim = 4096

def incorrect_loss(q, k):
    # Inlined from the repo's _get_contrastive_predictions() method; the branches that
    # reference self.queue, T, K, eqco_alpha, use_eqco_margin, and utils are never
    # reached with the settings above and are kept only to mirror the original code.
    if use_negative_examples_from_batch:
        logits = torch.mm(q, k.T)
        labels = torch.arange(0, q.shape[0], dtype=torch.long).to(logits.device)
        return logits, labels

    # compute logits
    # Einstein sum is more intuitive
    # positive logits: Nx1
    l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)

    if use_negative_examples_from_queue:
        # negative logits: NxK
        l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])
        logits = torch.cat([l_pos, l_neg], dim=1)
    else:
        logits = l_pos

    # labels: positive key indicators
    labels = torch.zeros(logits.shape[0], dtype=torch.long).to(logits.device)

    # _get_contrastive_loss() method
    if loss_type == "ce":
        if use_eqco_margin:
            if use_negative_examples_from_batch:
                neg_factor = eqco_alpha / batch_size
            elif use_negative_examples_from_queue:
                neg_factor = eqco_alpha / K
            else:
                raise Exception("Must have negative examples for ce loss")

            predictions = utils.log_softmax_with_factors(logits / T, neg_factor=neg_factor)
            return F.nll_loss(predictions, labels)
        return F.cross_entropy(logits / T, labels)

    new_labels = torch.zeros_like(logits)
    new_labels.scatter_(1, labels.unsqueeze(1), 1)

    if loss_type == "bce":
        return F.binary_cross_entropy_with_logits(logits / T, new_labels) * logits.shape[1]

    if loss_type == "ip":
        # inner product
        # negative sign for label=1 (maximize ip), positive sign for label=0 (minimize ip)
        inner_product = (1 - new_labels * 2) * logits
        return torch.mean((inner_product + 1).sum(dim=-1))

    raise NotImplementedError(f"Loss function {loss_type} not implemented")

def correct_loss(q, k):
    # BYOL paper loss (Eq. 3): 2 - 2 * cosine similarity, averaged over the batch
    # (q and k are L2-normalized below, so q . k is the cosine similarity).
    return 2 - 2 * (q * k).sum() / batch_size

def correct_loss_alt_formula(q, k):
    # Equivalent form: mean squared Euclidean distance between the normalized vectors.
    return ((q - k) ** 2).sum(dim=-1).mean()

if __name__ == "__main__":
    q = torch.randn(batch_size, output_dim)
    k = torch.randn(batch_size, output_dim)

    q = torch.nn.functional.normalize(q, dim=1)
    k = torch.nn.functional.normalize(k, dim=1)

    i_loss = incorrect_loss(q, k)
    c_loss = correct_loss(q, k)
    ca_loss = correct_loss_alt_formula(q, k)

    # The first two differences are far from zero, showing that the implemented loss
    # does not match the paper's; the third is ~0, since the two correct forms agree.
    print(i_loss - c_loss)
    print(i_loss - ca_loss)
    print(c_loss - ca_loss)
abefetterman commented 3 years ago

Thanks for the great reproduction! It turns out to be a documentation error: our "inner product" loss function is half of the loss function in the BYOL paper. This only changes the optimal learning rate for SGD.
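
For concreteness, a quick illustrative check (random unit-normalized vectors, not code from the repo): with normalized q and k, the implemented "ip" loss evaluates to 1 - mean(q · k) over the batch, while Eq. (3) of the BYOL paper gives 2 - 2 * mean(q · k), exactly twice as much.

import torch
import torch.nn.functional as F

# Illustrative only: verify the factor of two between the implemented "ip" loss
# and the BYOL-paper loss, using random unit-normalized projections.
q = F.normalize(torch.randn(128, 4096), dim=1)
k = F.normalize(torch.randn(128, 4096), dim=1)

ip_loss = 1 - torch.einsum("nc,nc->n", q, k).mean()         # the "ip" loss as implemented
paper_loss = 2 - 2 * torch.einsum("nc,nc->n", q, k).mean()  # BYOL paper, Eq. (3)
assert torch.allclose(2 * ip_loss, paper_loss)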

To mimic the loss function of the BYOL paper more precisely, we have added a loss_constant_factor hyperparameter. Set it to 2 to reproduce the paper's loss. We have added this suggestion to the BYOL section of the README.
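
A rough configuration sketch follows; the parameter class name and import path shown here are assumptions for illustration, and only the loss_constant_factor hyperparameter itself comes from the fix above.

# Hypothetical sketch: class and module names are assumed, not taken from the repo.
from model_params import ModelParams  # assumed import path for the hyperparameter dataclass

hparams = ModelParams(
    loss_type="ip",
    use_negative_examples_from_batch=False,
    use_negative_examples_from_queue=False,
    loss_constant_factor=2,  # scales the "ip" loss by 2 to match the BYOL-paper loss
)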

Edit: Why did we choose this loss? The inner-product loss used in our code corresponds to the standard InfoNCE loss without the softmax. Removing the factor of two introduced in the BYOL paper makes it easier to compare across implementations. The factor is most visible in their Eq. (3), which is twice the usual InfoNCE loss.
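
To make that comparison concrete, here is a toy sketch (random tensors, not the repo's code) showing that the "ce" loss is standard InfoNCE over [positive, negatives] logits, while the "ip" loss uses the same logits with the softmax removed.

import torch
import torch.nn.functional as F

# Toy illustration: "ce" is InfoNCE (cross-entropy over [positive, negatives] logits);
# "ip" keeps the same logits but drops the softmax, leaving a signed sum of inner products.
q = F.normalize(torch.randn(8, 16), dim=1)
k = F.normalize(torch.randn(8, 16), dim=1)
queue = F.normalize(torch.randn(16, 32), dim=0)  # 32 toy negative keys, one per column

l_pos = torch.einsum("nc,nc->n", q, k).unsqueeze(-1)  # Nx1 positive logits
l_neg = torch.einsum("nc,ck->nk", q, queue)           # NxK negative logits
logits = torch.cat([l_pos, l_neg], dim=1)
labels = torch.zeros(logits.shape[0], dtype=torch.long)

infonce = F.cross_entropy(logits, labels)             # the "ce" loss (InfoNCE)
signs = torch.ones_like(logits)
signs[:, 0] = -1                                      # maximize the positive, minimize negatives
ip = ((signs * logits) + 1).sum(dim=-1).mean()        # the "ip" loss (no softmax)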