lightly-ai / lightly

A Python library for self-supervised learning on images.
https://docs.lightly.ai/self-supervised-learning/
MIT License

Feature Request: Return all losses with multi-loss methods e.g., VICReg #1422

Open · RylanSchaeffer opened 10 months ago

RylanSchaeffer commented 10 months ago

Some SSL methods have a total loss that is a weighted combination of several component losses. For instance, the VICReg loss is a weighted sum of three terms: invariance, variance, and covariance.

It would be really nice if all of the component losses were returned by VICRegLoss as well.

The minimal use case is logging/monitoring. For instance, if the VICReg loss becomes unstable, the user can't necessarily tell which of the three component losses is misbehaving.
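
To make the use case concrete, here is a minimal sketch of the current behaviour; the tensors are just random placeholders standing in for the projections of two augmented views:

    import torch
    from lightly.loss import VICRegLoss

    criterion = VICRegLoss()
    z_a = torch.randn(8, 128)  # projection of view A
    z_b = torch.randn(8, 128)  # projection of view B
    loss = criterion(z_a, z_b)  # a single scalar; the three components are not exposed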

RylanSchaeffer commented 10 months ago

Another example with multiple component losses is TiCo: https://arxiv.org/abs/2206.10698

RylanSchaeffer commented 10 months ago

A third example with multiple component losses is HypersphereLoss: https://docs.lightly.ai/self-supervised-learning/lightly.loss.html#lightly.loss.hypersphere_loss.HypersphereLoss

guarin commented 10 months ago

Yes, this is a common issue. We currently return only a single loss because it slightly simplifies the code and allows you to swap loss functions without any other code changes (you don't have to handle aggregation of the different parts yourself).

For VICReg you can compute the components individually:

from lightly.loss.vicreg_loss import invariance_loss, variance_loss, covariance_loss

# z_a and z_b are the projection outputs for the two augmented views.
inv_loss = invariance_loss(x=z_a, y=z_b)
var_loss = 0.5 * (variance_loss(x=z_a) + variance_loss(x=z_b))
cov_loss = covariance_loss(x=z_a) + covariance_loss(x=z_b)
print(inv_loss, var_loss, cov_loss)
# 25.0, 25.0, and 1.0 are the default lambda_param, mu_param, and nu_param of VICRegLoss.
total_loss = 25.0 * inv_loss + 25.0 * var_loss + 1.0 * cov_loss

See also #1161
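
In a single-process setup with the default weights (lambda_param=25.0, mu_param=25.0, nu_param=1.0) this manual combination should match the value returned by VICRegLoss. A quick sanity check, reusing the z_a, z_b embeddings and the total_loss from the snippet above:

import torch
from lightly.loss import VICRegLoss

criterion = VICRegLoss()  # defaults: lambda_param=25.0, mu_param=25.0, nu_param=1.0
# total_loss is the weighted sum assembled from the component functions above.
assert torch.isclose(total_loss, criterion(z_a, z_b))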

For TiCo and HypersphereLoss we should definitely add the possibility to calculate the components individually.

Out of curiosity: when logging the individual parts, would you log them with or without the loss weights? For example, in VICReg each loss part has a weight (lambda_param, mu_param, nu_param) that is multiplied with the corresponding loss term.

RylanSchaeffer commented 10 months ago

For VICReg you can compute the components individually:

I was previously implementing all of this SSL stuff myself for a research project, then discovered Lightly and thought "Awesome! Why reinvent the wheel? I'll just switch to Lightly." I can of course compute the components individually, but I would prefer to either use Lightly or not, rather than mix and match.

A pattern I was using that works well is for each loss function to return a dictionary by default, from which one can then access the combined loss, e.g.:


    def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Returns VICReg loss.

        Args:
            z_a:
                Tensor with shape (batch_size, ..., dim).
            z_b:
                Tensor with shape (batch_size, ..., dim).
        """
        assert (
            z_a.shape[0] > 1 and z_b.shape[0] > 1
        ), f"z_a and z_b must have batch size > 1 but found {z_a.shape[0]} and {z_b.shape[0]}"
        assert (
            z_a.shape == z_b.shape
        ), f"z_a and z_b must have same shape but found {z_a.shape} and {z_b.shape}."

        # invariance term of the loss
        inv_loss = invariance_loss(x=z_a, y=z_b)

        # gather all batches
        if self.gather_distributed and dist.is_initialized():
            world_size = dist.get_world_size()
            if world_size > 1:
                z_a = torch.cat(gather(z_a), dim=0)
                z_b = torch.cat(gather(z_b), dim=0)

        # variance and covariance terms of the loss
        var_loss = 0.5 * (
            variance_loss(x=z_a, eps=self.eps) + variance_loss(x=z_b, eps=self.eps)
        )
        cov_loss = covariance_loss(x=z_a) + covariance_loss(x=z_b)

        loss = (
            self.lambda_param * inv_loss
            + self.mu_param * var_loss
            + self.nu_param * cov_loss
        )

        loss_results_dict = {
            "inv_loss": inv_loss,
            "var_loss": var_loss,
            "cov_loss": cov_loss,
            "total_loss": loss,
        }

        return loss_results_dict

Out of curiosity, when logging the individual parts would you log them with or without the loss weights?

Without the weights, definitely.
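
Concretely, something like this is what I have in mind (criterion here is the hypothetical dict-returning loss from my snippet above, and log stands in for whatever logging call the training framework provides):

    losses = criterion(z_a, z_b)
    losses["total_loss"].backward()  # the weighted sum still drives optimization
    for name in ("inv_loss", "var_loss", "cov_loss"):
        log(name, losses[name].detach())  # unweighted components, for monitoring only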