Thanks for the suggestion.
Actually, this one also does it; both losses are closely related: https://arxiv.org/pdf/2103.03230.pdf (Barlow Twins)
Since VICRegLoss is really just about adding regularization based on the embedding vectors, it seems like defining a loss function with a built-in VICRegularizer could do the job.
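For example, something along these lines (a rough sketch only; VICRegularizer is a hypothetical class, and I'm assuming the library's existing embedding_regularizer hook and the BaseRegularizer loss-dict format shown further down in this thread). Note that such a regularizer only sees a single batch of embeddings, so it could cover the variance and covariance terms but not the invariance term:
import torch
from pytorch_metric_learning import losses
from pytorch_metric_learning.regularizers import BaseRegularizer

# Hypothetical regularizer sketch: it only receives one embeddings tensor,
# so it can only apply the variance and covariance penalties.
class VICRegularizer(BaseRegularizer):
    def compute_loss(self, embeddings):
        # variance term: hinge on the per-dimension standard deviation
        std = torch.sqrt(embeddings.var(dim=0) + 1e-4)
        variance = torch.mean(torch.relu(1.0 - std))
        # covariance term: squared off-diagonal entries of the covariance matrix
        centered = embeddings - embeddings.mean(dim=0)
        cov = (centered.T @ centered) / (embeddings.size(0) - 1)
        off_diag = cov - torch.diag(torch.diag(cov))
        covariance = (off_diag ** 2).sum() / embeddings.size(1)
        loss = variance + covariance
        return {"loss": {"losses": loss, "indices": None, "reduction_type": "already_reduced"}}

loss_fn = losses.TripletMarginLoss(embedding_regularizer=VICRegularizer())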
But there's a problem that I think is extremely hard to solve for integration: the embeddings must come in two aligned groups (the original batch and its augmented counterpart). So we can't just do:
loss_fn = VICRegLoss()
embeddings = torch.cat((original_embds, augmented_embeds))
labels = torch.cat((labels, labels))
loss_fn(embeddings, labels)
That doesn't work because the regularization terms (v(Z) and v(Z') in the diagram) are calculated separately for original_embds and augmented_embds.
Seems like this requires something like
loss_fn(orig_embeds, aug_embeds, labels)
but that looks like an antipattern with the way things work right now. Could there be another way?
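To make the constraint concrete, here is a quick illustrative sketch (the variance_term helper is my own, following the paper's v(Z) definition) showing that the variance term computed per view is not the same quantity as one computed on a concatenated batch:
import torch

def variance_term(Z, gamma=1.0, eps=1e-4):
    # hinge on the per-dimension standard deviation, as in the paper's v(Z)
    std = torch.sqrt(Z.var(dim=0) + eps)
    return torch.mean(torch.relu(gamma - std))

original_embds = torch.randn(32, 128)
augmented_embds = original_embds + 0.01 * torch.randn(32, 128)

# what VICReg needs: v(Z) and v(Z') computed separately per view
per_view = variance_term(original_embds) + variance_term(augmented_embds)

# what the concatenated form would give: a single v over both views mixed together
concatenated = variance_term(torch.cat((original_embds, augmented_embds)))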
I have a branch in progress which adds optional ref_emb and ref_labels arguments to BaseMetricLossFunction. The ref inputs are for positives and negatives, while the embeddings input is for anchors. (If the refs are not provided, then it works like the current version of the code.) It could also be added to BaseRegularizer. This might simplify the implementation of VICRegLoss or VICRegularizer, because embeddings could represent Z and ref_emb could represent Z'. So we could wait for the feat-losses-with-ref branch to be merged before tackling this.
Alternatively, for the current codebase, we could require that the embeddings are packaged a certain way, i.e. the first half of the batch is Z and the second half is Z'. An implementation of VICRegLoss could do various checks like
# each label in the first half should be unique
len(torch.unique(labels[:halfway])) == len(labels[:halfway])
and
# 1st half and 2nd half should have the same labels
torch.equal(labels[:halfway], labels[halfway:])
An implementation of VICRegularizer wouldn't be able to check anything, because regularizers are only given embeddings (no labels).
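To make that packaging convention concrete, here is a hedged sketch (split_and_check is a hypothetical helper, not library code) of how a compute_loss could recover Z and Z' from a single batch and run the checks above:
import torch

def split_and_check(embeddings, labels):
    # convention: first half of the batch is Z, second half is Z'
    halfway = embeddings.size(0) // 2
    Z, Z_prime = embeddings[:halfway], embeddings[halfway:]
    # each label in the first half should be unique
    assert len(torch.unique(labels[:halfway])) == len(labels[:halfway])
    # 1st half and 2nd half should have the same labels, in the same order,
    # so that row i of Z and row i of Z' form an (original, augmented) pair
    assert torch.equal(labels[:halfway], labels[halfway:])
    return Z, Z_prime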
Option 1: Wait for the other branch to merge
That branch is now merged.
@KevinMusgrave Right, the implementation seems pretty intuitive with the new merge. Happy to submit a pull request if nobody's working on it right now.
@codeandproduce Yes please do!
@KevinMusgrave
The function itself is easily done, but I'm not sure if this can fit at all into the current BaseMetricLossFunction framework.
class VICRegLoss(BaseMetricLossFunction):
    def __init__(
        self,
        invariance_lambda=25,
        variance_mu=25,
        covariance_v=1,
        eps=1e-4,
        **kwargs
    ):
        """
        The overall loss function is a weighted average of the invariance, variance and covariance terms:
        L(Z, Z') = λs(Z, Z') + µ[v(Z) + v(Z')] + ν[c(Z) + c(Z')],
        where λ, µ and ν are hyper-parameters controlling the importance of each term in the loss.
        """
        super().__init__(**kwargs)
        self.invariance_lambda = invariance_lambda
        self.variance_mu = variance_mu
        self.covariance_v = covariance_v
        self.eps = eps

    def compute_loss(self, embeddings, ref_emb):
        invariance_loss = self.invariance_loss(embeddings, ref_emb)
        variance_loss = self.variance_loss(embeddings, ref_emb)
        covariance_loss = self.covariance_loss(embeddings, ref_emb)
        loss = (
            self.invariance_lambda * invariance_loss
            + self.variance_mu * variance_loss
            + self.covariance_v * covariance_loss
        )
        return loss

    # ...implementation of the above individual loss functions
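For reference, here is a minimal sketch of what those elided methods could look like, following the paper's definitions (the γ = 1 threshold in the variance term is the paper's default; these are my reading of the formulas, not the final implementation):
import torch

# possible methods for the VICRegLoss class above (a sketch, not the actual implementation)
def invariance_loss(self, emb, ref_emb):
    # s(Z, Z'): mean squared distance between paired embeddings
    return torch.nn.functional.mse_loss(emb, ref_emb)

def variance_loss(self, emb, ref_emb):
    # v(Z) + v(Z'): hinge on the per-dimension standard deviation of each view
    std_emb = torch.sqrt(emb.var(dim=0) + self.eps)
    std_ref = torch.sqrt(ref_emb.var(dim=0) + self.eps)
    return torch.mean(torch.relu(1.0 - std_emb)) + torch.mean(torch.relu(1.0 - std_ref))

def covariance_loss(self, emb, ref_emb):
    # c(Z) + c(Z'): sum of squared off-diagonal covariance entries, divided by the embedding dimension
    def off_diag_cov(x):
        n, d = x.size()
        x = x - x.mean(dim=0)
        cov = (x.T @ x) / (n - 1)
        off_diag = cov - torch.diag(torch.diag(cov))
        return (off_diag ** 2).sum() / d
    return off_diag_cov(emb) + off_diag_cov(ref_emb)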
The idea is that if you have a batch of images I and some augmented version of those images I' (e.g. a flip), and your model embeds these images as E(I) = Z, then:
loss_fn = VICRegLoss()
Z = E(I)
Z_prime = E(I_prime)  # Z' = E(I') in the notation above
loss = loss_fn(Z, Z_prime)
It doesn't use any notion of labels or indices_tuple.
I was considering forcing this framework with something like this:
# assume Z.size() == (5, 255)
embeds = Z
ref_embeds = Z_prime
labels = torch.tensor([1, 2, 3, 4, 5])
loss = loss_fn(embeds, labels, ref_emb=ref_embeds, ref_labels=labels)
but this is very unnecessary; it's not like this function could benefit from handling cases where ref_labels is not exactly the same as labels.
But the specific thing making my current implementation impossible is the following forward() function in the BaseMetricLossFunction class:
loss_dict = self.compute_loss(
    embeddings, labels, indices_tuple, ref_emb, ref_labels
)
self.add_embedding_regularization_to_loss_dict(loss_dict, embeddings)
return self.reducer(loss_dict, embeddings, labels)
since it expects a dict that looks something like:
return {
    "loss": {
        "losses": loss,
        "indices": indices_tuple,
        "reduction_type": "triplet",
    }
}
and VICRegLoss doesn't use any notion of indices or reduction.
I'm wondering if I should just write a new base class, something like BaseSimpleLossFunction, that directly returns the loss from the child class's compute_loss, with the purpose of being the most bare-bones and flexible superclass:
class BaseSimpleLossFunction:
    def compute_loss(
        self, embeddings, labels, indices_tuple=None, ref_emb=None, ref_labels=None
    ):
        """
        This has to be implemented and is what actually computes the loss.
        """
        raise NotImplementedError

    def forward(
        self, embeddings=None, labels=None, indices_tuple=None, ref_emb=None, ref_labels=None
    ):
        # return the child's loss directly, with no reducer or regularizer plumbing
        loss = self.compute_loss(
            embeddings, labels, indices_tuple, ref_emb, ref_labels
        )
        return loss

    def zero_loss(self):
        return {"losses": 0, "indices": None, "reduction_type": "already_reduced"}

    def zero_losses(self):
        return {loss_name: self.zero_loss() for loss_name in self.sub_loss_names()}

    def _sub_loss_names(self):
        return ["loss"]

    def sub_loss_names(self):
        # assumes a regularizer mixin provides all_regularization_loss_names()
        return self._sub_loss_names() + self.all_regularization_loss_names()
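With a base class like that, the earlier VICRegLoss sketch could presumably be called without labels at all (Z and Z_prime being the two embedding views from the earlier example):
loss_fn = VICRegLoss()  # now subclassing BaseSimpleLossFunction
loss = loss_fn(embeddings=Z, ref_emb=Z_prime)  # labels and indices_tuple stay None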
On another note, maybe it's also relevant to consider how
loss = (
    self.invariance_lambda * invariance_loss
    + self.variance_mu * variance_loss
    + self.covariance_v * covariance_loss
)
is basically the same thing as:
vic_regularization = self.variance_mu * variance_loss + self.covariance_v * covariance_loss
loss = mse(Z, Z') + vic_regularization
return loss
From the paper:
• Invariance: the mean square distance between the embedding vectors.
• Variance: a hinge loss to maintain the standard deviation (over a batch) of each variable of the embedding above a given threshold. This term forces the embedding vectors of samples within a batch to be different.
• Covariance: a term that attracts the covariances (over a batch) between every pair of (centered) embedding variables towards zero. This term decorrelates the variables of each embedding and prevents an informational collapse in which the variables would vary together or be highly correlated.
Does this help form a solution at all?
I guess BaseRegularizer is already kind of like your BaseSimpleLossFunction. But it is missing a ref argument from its forward function.
For now, you could try overriding the forward function and see if the design is reasonable:
from .base_regularizer import BaseRegularizer
from ..utils import common_functions as c_f

class VICRegLoss(BaseRegularizer):
    def forward(self, x, y):
        self.reset_stats()
        loss_dict = self.compute_loss(x, y)
        # Don't worry about x being passed into the reducer here.
        # I don't think it's even used in the currently available reducers.
        return self.reducer(loss_dict, x, c_f.torch_arange_from_size(x))

    def compute_loss(self, x, y):
        invariance_loss = self.invariance_loss(x, y)
        variance_loss = self.variance_loss(x, y)
        covariance_loss = self.covariance_loss(x, y)
        loss = (
            self.invariance_lambda * invariance_loss
            + self.variance_mu * variance_loss
            + self.covariance_v * covariance_loss
        )
        return {
            "loss": {
                "losses": loss,
                "indices": None,
                "reduction_type": "already_reduced",
            }
        }
It might be possible to keep the losses unreduced. For example, the mean square distance between vectors can be kept as per-pair (i.e. not passed into torch.mean).
The "losses", "indices", and "reduction_type" keys are a bit confusing. I think the reducers are overly complicated in general. Anyway, if you decide to keep the losses unreduced, you can return something like this:
return {
    "invariance_loss": {
        "losses": invariance_loss,
        "indices": c_f.torch_arange_from_size(invariance_loss),
        "reduction_type": "element",
    },
    # Repeat for the other 2 losses.
    # They can be "already_reduced" if there is no good way of keeping them unreduced.
}
# also override this function
def sub_loss_names(self):
    return ["invariance_loss", ...]
But to keep things simple, I would start with the "already_reduced" version.
Re: the components of the VIC loss, you could write them as simple loss functions (i.e. extend torch.nn.Module), and then use them inside of VICRegLoss.
Ultimately, if you find it makes more sense to extend torch.nn.Module, then that's ok too.
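For example, the variance term could live in its own small module (a sketch following the paper's v(Z) definition; VarianceLoss is a hypothetical name):
import torch

class VarianceLoss(torch.nn.Module):
    def __init__(self, gamma=1.0, eps=1e-4):
        super().__init__()
        self.gamma = gamma
        self.eps = eps

    def forward(self, emb, ref_emb):
        # hinge on the per-dimension standard deviation of each view
        std_emb = torch.sqrt(emb.var(dim=0) + self.eps)
        std_ref = torch.sqrt(ref_emb.var(dim=0) + self.eps)
        return torch.mean(torch.relu(self.gamma - std_emb)) + torch.mean(
            torch.relu(self.gamma - std_ref)
        )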
Available in v1.1:
pip install pytorch-metric-learning==1.1.0
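Assuming the released API follows the ref_emb design discussed above (worth double-checking against the library's docs), usage would presumably look something like:
import torch
from pytorch_metric_learning import losses

loss_fn = losses.VICRegLoss()
Z = torch.randn(32, 128)        # embeddings of the original batch
Z_prime = torch.randn(32, 128)  # embeddings of the augmented batch
loss = loss_fn(Z, ref_emb=Z_prime)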
Hi, will there be an update adding the VICReg loss function? It is a pretty cool loss that does not need negative samples. https://arxiv.org/abs/2105.04906