KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/
MIT License

VICReg loss ? #372

Closed: ralgayres closed this issue 2 years ago

ralgayres commented 3 years ago

Hi, will there be an update adding the VICReg loss function? It is a pretty cool loss that does not need negative samples. https://arxiv.org/abs/2105.04906

KevinMusgrave commented 3 years ago

Thanks for the suggestion

ralgayres commented 3 years ago

Actually, this one (Barlow Twins) also works without negative samples; the two losses are closely related: https://arxiv.org/pdf/2103.03230.pdf

cwkeam commented 3 years ago

Since VICRegLoss essentially adds regularization terms computed on the embedding vectors, defining a loss function with a built-in VICRegularizer could do the job.

The part that seems really hard to solve for integration, though, is that the embeddings must come in two groups (original and augmented):

[Diagram from the VICReg paper: the loss combines an invariance term s(Z, Z') between the two embedding batches with variance terms v(Z), v(Z') and covariance terms c(Z), c(Z') computed on each batch separately.]

So we can't just do:

loss_fn = VICRegLoss()
embeddings = torch.cat((original_embeds, augmented_embeds))
labels = torch.cat((labels, labels))
loss_fn(embeddings, labels)

Since the regularization terms (v(Z) and v(Z') in the diagram) are calculated separately on original_embeds and augmented_embeds.

Seems like this requires something like

loss_fn(orig_embeds, aug_embeds, labels)

But it seems like that's an antipattern with the way things work right now.

Could there be another way?

KevinMusgrave commented 3 years ago

> Since the regularization terms (v(Z) and v(Z') in the diagram) are calculated separately on original_embeds and augmented_embeds.
>
> Seems like this requires something like
>
> loss_fn(orig_embeds, aug_embeds, labels)
>
> But it seems like that's an antipattern with the way things work right now.
>
> Could there be another way?

Option 1: Wait for other branch to merge

I have a branch in progress which adds optional ref_emb and ref_labels arguments to BaseMetricLossFunction. The ref inputs are for positives and negatives, while the embeddings input is for anchors. (If the refs are not provided, then it works like the current version of code.) It could also be added to BaseRegularizer. This might simplify the implementation of VICRegLoss or VICRegularizer, because embeddings could represent Z, and ref_emb could represent Z'. So we could wait for feat-losses-with-ref branch to be merged before tackling this.
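
For illustration, here is a rough sketch of how the ref arguments could map onto VICReg once that branch lands. The call pattern follows the planned forward signature, but it is only an assumption, and names like model, images, and augmented are placeholders:

# Hypothetical sketch (not a final API): embeddings plays the role of Z,
# ref_emb plays the role of Z'.
loss_fn = VICRegLoss()

Z = model(images)                 # anchor embeddings
Z_prime = model(augmented)        # augmented counterparts, row-aligned with Z

labels = torch.arange(len(Z))     # row i of Z pairs with row i of Z_prime
loss = loss_fn(Z, labels, ref_emb=Z_prime, ref_labels=labels)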

Option 2: Make it work with current code

Alternatively, for the current codebase, we could require that the embeddings are packaged a certain way, i.e. the first half of the batch is Z and the second half is Z'. An implementation of VICRegLoss could do various checks, like

# each label in the first half should be unique
len(torch.unique(labels[:halfway])) == len(labels[:halfway])

and

# 1st half and 2nd half should have same labels
torch.equal(labels[:halfway], labels[halfway:])

An implementation of VICRegularizer wouldn't be able to check anything because regularizers are only given embeddings (no labels).
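
As a rough sketch, VICRegLoss could unpack and validate such a packed batch like this (the helper name and assert-based checks are my own illustration, not existing library code):

import torch

def split_and_check(embeddings, labels):
    # Assumed packing convention: first half of the batch is Z, second half is Z'.
    halfway = embeddings.size(0) // 2
    Z, Z_prime = embeddings[:halfway], embeddings[halfway:]

    # Each label in the first half should be unique (one anchor per pair).
    assert len(torch.unique(labels[:halfway])) == len(labels[:halfway])
    # The two halves should carry the same labels, in the same order.
    assert torch.equal(labels[:halfway], labels[halfway:])

    return Z, Z_prime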

KevinMusgrave commented 2 years ago

Option 1: Wait for other branch to merge

That branch is now merged

cwkeam commented 2 years ago

@KevinMusgrave Right, the implementation seems pretty intuitive with the new merge. Happy to submit a pull request if nobody's working on it right now.

KevinMusgrave commented 2 years ago

@codeandproduce Yes please do!

cwkeam commented 2 years ago

@KevinMusgrave

The function itself is easily done, but I'm not sure if this can fit at all into the current BaseMetricLossFunction framework.


from .base_metric_loss_function import BaseMetricLossFunction


class VICRegLoss(BaseMetricLossFunction):
    """
    The overall loss function is a weighted average of the invariance, variance and covariance terms:
        L(Z, Z') = λs(Z, Z') + µ[v(Z) + v(Z')] + ν[c(Z) + c(Z')],
    where λ, µ and ν are hyper-parameters controlling the importance of each term in the loss.
    """

    def __init__(
        self,
        invariance_lambda=25,
        variance_mu=25,
        covariance_v=1,
        eps=1e-4,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.invariance_lambda = invariance_lambda
        self.variance_mu = variance_mu
        self.covariance_v = covariance_v
        self.eps = eps

    def compute_loss(self, embeddings, ref_emb):
        invariance_loss = self.invariance_loss(embeddings, ref_emb)
        variance_loss = self.variance_loss(embeddings, ref_emb)
        covariance_loss = self.covariance_loss(embeddings, ref_emb)

        loss = self.invariance_lambda * invariance_loss + self.variance_mu * variance_loss + self.covariance_v * covariance_loss
        return loss

# ...implementation of the individual invariance/variance/covariance loss terms

The idea is that if you have a batch of images I, some augmented version of those images I' (e.g. a flip), and your model E embeds them as E(I) = Z and E(I') = Z', then:

loss_fn = VICRegLoss()

Z = E(I)
Z_prime = E(I_prime)

loss = loss_fn(Z, Z_prime)

It doesn't use any notion of labels or indices_tuple.

I was considering forcing this into the framework with something like this:

# assume Z.size() == (5, 255)
embeds = Z
ref_embeds = Z_prime
labels = torch.tensor([1, 2, 3, 4, 5])

loss = loss_fn(embeds, labels, ref_emb=ref_embeds, ref_labels=labels)

but this is very unnecessary. It's not like this function could benefit from handling cases where ref_labels is not exactly the same as labels.

But the specific thing making my current implementation impossible is this part of the forward() function in the BaseMetricLossFunction class:

loss_dict = self.compute_loss(
    embeddings, labels, indices_tuple, ref_emb, ref_labels
)
self.add_embedding_regularization_to_loss_dict(loss_dict, embeddings)
return self.reducer(loss_dict, embeddings, labels)

since it expects a dict that looks something like:

 return {
    "loss": {
        "losses": loss,
        "indices": indices_tuple,
        "reduction_type": "triplet",
    }
}

and VICRegLoss doesn't use any notion of indices or reduction.

I'm wondering if I should just write a new base class, something like BaseSimpleLossFunction, that directly returns whatever the child class's compute_loss produces, with the goal of being the most bare-bones and flexible superclass.

import torch


class BaseSimpleLossFunction(torch.nn.Module):
    def compute_loss(
        self, embeddings, labels, indices_tuple=None, ref_emb=None, ref_labels=None
    ):
        """
        This has to be implemented and is what actually computes the loss.
        """
        raise NotImplementedError

    def forward(
        self, embeddings=None, labels=None, indices_tuple=None, ref_emb=None, ref_labels=None
    ):
        # Directly return whatever the child class's compute_loss returns,
        # with no reducer or loss_dict machinery in between.
        loss = self.compute_loss(
            embeddings, labels, indices_tuple, ref_emb, ref_labels
        )
        return loss

    def zero_loss(self):
        return {"losses": 0, "indices": None, "reduction_type": "already_reduced"}

    def zero_losses(self):
        return {loss_name: self.zero_loss() for loss_name in self.sub_loss_names()}

    def _sub_loss_names(self):
        return ["loss"]

    def sub_loss_names(self):
        # Assumes a regularizer mixin (as in the existing base classes)
        # provides all_regularization_loss_names().
        return self._sub_loss_names() + self.all_regularization_loss_names()

On another note, maybe it's also relevant to consider how:

loss = self.invariance_lambda * invariance_loss + self.variance_mu * variance_loss + self.covariance_v * covariance_loss

This is basically the same thing as:

vic_regularization = self.variance_mu * variance_loss + self.covariance_v * covariance_loss
loss = mse(Z, Z_prime) + vic_regularization
return loss

From the paper:

• Invariance: the mean square distance between the embedding vectors.
• Variance: a hinge loss to maintain the standard deviation (over a batch) of each variable of the embedding above a given threshold. This term forces the embedding vectors of samples within a batch to be different.
• Covariance: a term that attracts the covariances (over a batch) between every pair of (centered) embedding variables towards zero. This term decorrelates the variables of each embedding and prevents an informational collapse in which the variables would vary together or be highly correlated.
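
For concreteness, here is a rough sketch of those three terms in PyTorch. The function names, the eps default, and the std threshold of 1 follow the paper's description; these are my own illustrative functions, not the library's implementation:

import torch
import torch.nn.functional as F

def invariance_loss(Z, Z_prime):
    # Mean squared distance between paired embedding vectors; Z, Z_prime are (N, D).
    return ((Z - Z_prime) ** 2).sum(dim=1).mean()

def variance_loss(Z, eps=1e-4):
    # Hinge loss keeping the per-dimension std (over the batch) above 1.
    std = torch.sqrt(Z.var(dim=0) + eps)
    return torch.mean(F.relu(1 - std))

def covariance_loss(Z):
    # Push the off-diagonal entries of the covariance matrix towards zero.
    N, D = Z.size()
    Z = Z - Z.mean(dim=0)
    cov = (Z.T @ Z) / (N - 1)
    off_diag = cov - torch.diag(torch.diag(cov))
    return (off_diag ** 2).sum() / D

The variance and covariance terms would be applied to both Z and Z' and summed, matching the weighted combination above.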

Does this help form a solution at all?

KevinMusgrave commented 2 years ago

I guess BaseRegularizer is already kind of like your BaseSimpleLossFunction. But it is missing a ref argument from its forward function.

For now, you could try overriding the forward function and see if the design is reasonable:

from .base_regularizer import BaseRegularizer
from ..utils import common_functions as c_f

class VICRegLoss(BaseRegularizer):
    def forward(self, x, y):
        self.reset_stats()
        loss_dict = self.compute_loss(x, y)
        # Don't worry about x being passed into reducer here.
        # I don't think it's even used in the currently available reducers.
        return self.reducer(loss_dict, x, c_f.torch_arange_from_size(x))

    def compute_loss(self, x, y):
        invariance_loss = self.invariance_loss(x, y)
        variance_loss = self.variance_loss(x, y)
        covariance_loss = self.covariance_loss(x, y)
        loss = self.invariance_lambda * invariance_loss + self.variance_mu * variance_loss + self.covariance_v * covariance_loss
        return {
            "loss": {
                "losses": loss,
                "indices": None,
                "reduction_type": "already_reduced",
            }
        }

It might be possible to keep the losses unreduced. For example, the mean square distance between vectors can be kept as per-pair (i.e. not passed into torch.mean).
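
For example, a per-pair (unreduced) invariance term could be kept as one value per row (just a sketch of the idea):

# Per-pair squared distance, shape (N,), left unreduced so the reducer
# can handle it with reduction_type "element".
invariance_loss = ((x - y) ** 2).sum(dim=1)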

The "losses", "indices", and "reduction_type" keys are a bit confusing. I think the reducers are overly complicated in general. Anyway, if you decide to keep the losses unreduced, you can return something like this:

        return {
            "invariance_loss": {
                "losses": invariance_loss,
                "indices": c_f.torch_arange_from_size(invariance_loss),
                "reduction_type": "element",
            },
            # Repeat for the other 2 losses,
            # They can be "already_reduced" if there is no good way of keeping them unreduced
        }

    # also override this function
    def sub_loss_names(self):
        return ["invariance_loss", ...]

But to keep things simple, I would start with the "already_reduced" version.

KevinMusgrave commented 2 years ago

Re: the components of the VIC loss, you could write them as simple loss functions (i.e. extend torch.nn.Module), and then use them inside of VICRegLoss.
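
For instance, one of the components written as a plain torch.nn.Module might look like this (a sketch, not the library's code):

import torch

class CovarianceLoss(torch.nn.Module):
    """Pushes the off-diagonal covariance entries towards zero, per the VICReg paper."""

    def forward(self, emb):
        N, D = emb.size()
        emb = emb - emb.mean(dim=0)
        cov = (emb.T @ emb) / (N - 1)
        off_diag = cov - torch.diag(torch.diag(cov))
        return (off_diag ** 2).sum() / D

VICRegLoss could then instantiate it in __init__ (e.g. self.covariance_loss = CovarianceLoss()) and call it on both x and y inside compute_loss.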

KevinMusgrave commented 2 years ago

Ultimately, if you find it makes more sense to extend torch.nn.Module, then that's ok too.

KevinMusgrave commented 2 years ago

Available in v1.1:

pip install pytorch-metric-learning==1.1.0
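
For reference, usage of the released loss should look roughly like this (double-check the docs, since the exact call signature here is my recollection rather than a quote):

import torch
from pytorch_metric_learning.losses import VICRegLoss

loss_fn = VICRegLoss()

embeddings = torch.randn(32, 128)   # Z
ref_emb = torch.randn(32, 128)      # Z', e.g. embeddings of augmented views

# Unlike most losses in the library, no labels are needed.
loss = loss_fn(embeddings, ref_emb)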