kkirchheim / pytorch-ood

👽 Out-of-Distribution Detection with PyTorch
https://pytorch-ood.readthedocs.io/
Apache License 2.0

Addition of Angular Loss and Angle-Based Detector #59

Open YacineBelHadj opened 3 months ago

YacineBelHadj commented 3 months ago

Description:

I would like to propose the addition of a new loss function and detector to the pytorch-ood library: an Angular Loss function (e.g., ArcFace) and an Angle-Based Detector. These additions aim to enhance the discriminative capabilities of models for Out-of-Distribution (OOD) detection by leveraging angular information.

Motivation:

Angular loss functions, such as Additive Angular Margin Loss (ArcFace), have demonstrated significant improvements in discriminative power for tasks like face recognition by optimizing angular margins between classes. Applying similar principles to OOD detection could improve the model's ability to distinguish between in-distribution and out-of-distribution samples.
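
For reference, the additive angular margin idea can be sketched in a few lines of PyTorch: features and per-class weight vectors are L2-normalized so their dot product is a cosine, the margin m is added to the angle of the ground-truth class only, and the result is scaled by s before cross-entropy. The class name `ArcFaceLoss` and the defaults s=64.0, m=0.5 are illustrative assumptions; this is neither pytorch-ood nor pytorch_metric_learning API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ArcFaceLoss(nn.Module):
    """Hypothetical ArcFace-style loss (additive angular margin), for illustration only."""

    def __init__(self, num_classes: int, embedding_size: int, s: float = 64.0, m: float = 0.5):
        super().__init__()
        self.s = s  # scale applied to the cosine logits
        self.m = m  # additive angular margin in radians
        # one learnable weight vector (class center) per class
        self.weight = nn.Parameter(torch.empty(num_classes, embedding_size))
        nn.init.xavier_uniform_(self.weight)

    def cosine(self, z: torch.Tensor) -> torch.Tensor:
        # cosine similarity between normalized features and normalized class weights
        return F.linear(F.normalize(z, dim=1), F.normalize(self.weight, dim=1))

    def forward(self, z: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # clamp keeps acos (and its gradient) finite
        cos = self.cosine(z).clamp(-1 + 1e-7, 1 - 1e-7)
        theta = torch.acos(cos)
        # add the margin to the angle of the ground-truth class only
        target = F.one_hot(y, num_classes=cos.shape[1]).bool()
        cos_margin = torch.where(target, torch.cos(theta + self.m), cos)
        return F.cross_entropy(self.s * cos_margin, y)
```

This omits the correction the original ArcFace applies when theta + m exceeds pi (where cos is no longer monotonic); with m=0 the loss reduces to a scaled cosine-softmax cross-entropy.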

An Angle-Based Detector would utilize angular distances between feature representations and class centers to identify OOD samples, potentially providing a more robust method for OOD detection.

Proposed Changes:

Angular Loss Implementation:
    Implement an angular margin loss (e.g., ArcFace), as done in pytorch_metric_learning.

Angle-Based Detector:
    Develop a new detector class that calculates angular distances between input sample features and class centers.
    Use these angular distances as a basis for OOD detection.
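
A minimal sketch of such a detector, assuming `model` maps inputs to feature vectors: each class center is the normalized mean of that class's L2-normalized training features, and the outlier score is the angle (in radians) to the closest center, so larger scores mean "more likely OOD". The class `AngularDetector` and its method names are hypothetical and do not subclass pytorch-ood's detector base class.

```python
import torch
import torch.nn.functional as F


class AngularDetector:
    """Hypothetical angle-based OOD detector (not part of pytorch-ood)."""

    def __init__(self, model: torch.nn.Module):
        self.model = model  # maps inputs to feature vectors
        self.centers = None  # (num_classes, feature_dim), unit norm

    @torch.no_grad()
    def fit_features(self, z: torch.Tensor, y: torch.Tensor) -> "AngularDetector":
        z = F.normalize(z, dim=1)
        classes = torch.unique(y)  # sorted class labels
        # class center = mean direction of the normalized features of that class
        centers = torch.stack([z[y == c].mean(dim=0) for c in classes])
        self.centers = F.normalize(centers, dim=1)
        return self

    @torch.no_grad()
    def predict_features(self, z: torch.Tensor) -> torch.Tensor:
        cos = F.normalize(z, dim=1) @ self.centers.T
        angles = torch.acos(cos.clamp(-1.0, 1.0))
        # outlier score: angle to the closest class center
        return angles.min(dim=1).values

    @torch.no_grad()
    def predict(self, x: torch.Tensor) -> torch.Tensor:
        return self.predict_features(self.model(x))
```

Because only directions matter, the score ignores feature magnitude entirely; whether that helps or hurts OOD detection would have to be checked empirically.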

kkirchheim commented 3 months ago

Hello, thank you for the proposal; it sounds reasonable. Are you aware of any publications that have evaluated such angle-based methods specifically for OOD detection?
Otherwise, we could still add ArcFace, as the task seems related. Would you be able to implement the loss and the corresponding detector?

YacineBelHadj commented 3 months ago

Hi,

Thanks for the prompt response. Angular margin losses are indeed used for OOD detection in machine monitoring based on sound spectrograms, as demonstrated in DCASE challenge Task 2 (see also: "Why do Angular Margin Losses Work Well for Semi-Supervised Anomalous Sound Detection").

Is it reasonable to add pytorch_metric_learning as a dependency? This package implements several losses that could be beneficial for the project. Alternatively, we could re-implement the most attractive one, which I believe is Sub-center ArcFace.

Best regards,
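
Sub-center ArcFace mainly changes how class cosines are formed: each class gets K learnable sub-centers, and the class cosine is the maximum over its sub-centers; the angular margin is then applied as in plain ArcFace. A minimal sketch under that reading, with the hypothetical names `SubCenterCosine` and `arcface_loss` and illustrative defaults (k=3, s=64.0, m=0.5):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SubCenterCosine(nn.Module):
    """Hypothetical sub-center cosine head: k learnable sub-centers per class."""

    def __init__(self, num_classes: int, embedding_size: int, k: int = 3):
        super().__init__()
        self.num_classes = num_classes
        self.k = k
        # rows [c*k, (c+1)*k) are the sub-centers of class c
        self.weight = nn.Parameter(torch.empty(num_classes * k, embedding_size))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # cosine to every sub-center: (batch, num_classes * k)
        cos = F.linear(F.normalize(z, dim=1), F.normalize(self.weight, dim=1))
        # group sub-centers by class and keep the closest one: (batch, num_classes)
        return cos.view(-1, self.num_classes, self.k).max(dim=2).values


def arcface_loss(cos: torch.Tensor, y: torch.Tensor, s: float = 64.0, m: float = 0.5) -> torch.Tensor:
    """Additive angular margin on the target-class cosine, then scaled cross-entropy."""
    theta = torch.acos(cos.clamp(-1 + 1e-7, 1 - 1e-7))
    target = F.one_hot(y, num_classes=cos.shape[1]).bool()
    logits = torch.where(target, torch.cos(theta + m), cos)
    return F.cross_entropy(s * logits, y)
```

The max over sub-centers lets noisy or multi-modal classes spread over several directions, which is the motivation usually given for the sub-center variant.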

kkirchheim commented 3 months ago

> Is it reasonable to add pytorch_metric_learning as a dependency?

I always try to avoid additional dependencies. However, we could make it an optional dependency, in the sense that the required libraries are loaded upon instantiation of the loss, and if not found, an exception is raised with a hint that the library is missing.
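
The optional-dependency pattern described here can be sketched with `importlib`: the import happens only when the loss is instantiated, and a missing package produces an `ImportError` with an installation hint. The helper `require` and the wrapper `ArcFaceLossWrapper` are hypothetical names; `pytorch_metric_learning.losses.ArcFaceLoss` is the real class in that library, but the wrapper itself is only an illustration.

```python
import importlib
from types import ModuleType
from typing import Optional


def require(module_name: str, pip_name: Optional[str] = None) -> ModuleType:
    """Import an optional dependency, raising a helpful error if it is missing."""
    try:
        return importlib.import_module(module_name)
    except ImportError as e:
        raise ImportError(
            f"'{module_name}' is required for this feature but is not installed. "
            f"Install it with: pip install {pip_name or module_name}"
        ) from e


class ArcFaceLossWrapper:
    """Hypothetical wrapper: the optional library is imported only on instantiation."""

    def __init__(self, **kwargs):
        losses = require("pytorch_metric_learning.losses", pip_name="pytorch-metric-learning")
        self.loss = losses.ArcFaceLoss(**kwargs)

    def __call__(self, embeddings, labels):
        return self.loss(embeddings, labels)
```

Importing pytorch-ood itself stays free of the extra dependency; only users who construct the wrapper need the package installed.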

Re-implementing also seems like a good idea.

YacineBelHadj commented 3 months ago

OK perfect, then I will start working on it as soon as I can :)

YacineBelHadj commented 2 months ago

I started looking into the package, and I am not convinced by the implementation of the Mahalanobis detector:

n_classes = len(classes)
self.mu = torch.zeros(size=(n_classes, z.shape[-1]), device=device)
self.cov = torch.zeros(size=(z.shape[-1], z.shape[-1]), device=device)

for clazz in range(n_classes):
    idxs = y.eq(clazz)
    assert idxs.sum() != 0
    zs = z[idxs]
    self.mu[clazz] = zs.mean(dim=0)
    self.cov += (zs - self.mu[clazz]).T.mm(zs - self.mu[clazz])

self.cov += torch.eye(self.cov.shape[0], device=self.cov.device) * 1e-6

I thought we were supposed to compute a mean and a covariance for each class, so the means and covariances should be stacked, rather than the covariances being summed on top of each other. Additionally, for the covariance estimate, using OAS from scikit-learn (from sklearn.covariance import OAS) might be a good idea, as it is theoretically more robust.
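
A sketch of how the OAS suggestion could be applied to a single shared covariance, assuming features and labels are available as NumPy arrays: each sample is centered on its own class mean, and the pooled residuals are passed to scikit-learn's `OAS` estimator with `assume_centered=True`. The function name `shared_covariance_oas` is hypothetical; for per-class covariances one would instead fit `OAS` on each class's features separately.

```python
import numpy as np
from sklearn.covariance import OAS


def shared_covariance_oas(z: np.ndarray, y: np.ndarray):
    """Class means plus one shared (tied) covariance, shrunk with OAS."""
    classes = np.unique(y)  # sorted, unique labels
    mu = np.stack([z[y == c].mean(axis=0) for c in classes])
    # center every sample on its own class mean; the residuals share one covariance
    residuals = z - mu[np.searchsorted(classes, y)]
    # residuals are already centered, so OAS must not subtract another mean
    oas = OAS(assume_centered=True).fit(residuals)
    return mu, oas.covariance_, oas.precision_
```

Shrinkage mainly matters when the feature dimension is large relative to the number of samples; with plenty of data it should end up close to the plain estimate.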

kkirchheim commented 2 months ago

The paper uses a "tied covariance" matrix (that is, a shared covariance matrix for all classes) and class conditional centers.
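
To make the "tied covariance" idea concrete, here is a minimal PyTorch sketch: one mean per class, a single covariance pooled over all within-class residuals and divided by the number of samples N (my reading of the paper's estimator), and a score equal to the squared Mahalanobis distance to the closest class mean. The function names are hypothetical and this is not pytorch-ood's implementation.

```python
import torch


def fit_tied_gaussian(z: torch.Tensor, y: torch.Tensor, eps: float = 1e-6):
    """Class-conditional means with one shared (tied) covariance matrix."""
    classes = torch.unique(y)  # sorted class labels
    mu = torch.stack([z[y == c].mean(dim=0) for c in classes])
    # residual of each sample with respect to the mean of its own class
    centered = z - mu[torch.searchsorted(classes, y)]
    cov = centered.T @ centered / z.shape[0]
    # small ridge keeps the matrix invertible
    cov = cov + eps * torch.eye(z.shape[1], dtype=z.dtype, device=z.device)
    return mu, torch.linalg.inv(cov)


def mahalanobis_score(z: torch.Tensor, mu: torch.Tensor, precision: torch.Tensor) -> torch.Tensor:
    """Squared Mahalanobis distance to the closest class mean (larger = more likely OOD)."""
    diff = z.unsqueeze(1) - mu.unsqueeze(0)  # (batch, classes, dim)
    d2 = torch.einsum("bcd,de,bce->bc", diff, precision, diff)
    return d2.min(dim=1).values
```

Sharing one covariance keeps the number of parameters independent of the number of classes, which makes the estimate far more stable when some classes have few samples.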

kkirchheim commented 2 months ago

It seems to me that you are passing the wrong arguments to the Mahalanobis detector. Please consider reading the documentation here carefully and having a look at the examples here.

The model argument should map inputs to features, so you should pass nn.features instead of nn. If you pass your entire network nn, the Mahalanobis detector will use the output of the network as "features".
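
The difference can be seen with a small, hypothetical network split into a `features` sub-module (inputs to embeddings) and a `classifier` sub-module (embeddings to logits), mirroring the nn.features attribute mentioned here. `TinyNet` is an illustration only, not the actual `ClassificationModel` definition.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical network: `features` maps inputs to embeddings, `classifier` to logits."""

    def __init__(self, num_inputs: int, n_hidden: int, num_outputs: int):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(num_inputs, n_hidden), nn.ReLU())
        self.classifier = nn.Linear(n_hidden, num_outputs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))


net = TinyNet(num_inputs=128, n_hidden=10, num_outputs=5)
x = torch.randn(4, 128)
print(net.features(x).shape)  # torch.Size([4, 10]): embeddings for the detector
print(net(x).shape)           # torch.Size([4, 5]): logits, not features
```

Passing `net` instead of `net.features` would therefore fit the class means in the 5-dimensional logit space rather than the 10-dimensional embedding space.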

The following test passes:

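import unittest

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_ood.detector import Mahalanobis

# ClassificationModel is a small helper network from the pytorch-ood test suite;
# this import path is an assumption and may differ in your checkout
from tests.helpers import ClassificationModel


class TestMahalanobis(unittest.TestCase):
    def test_mu_shape(self):
        number_classes = 5
        embedding_size = 10
        nn = ClassificationModel(num_inputs=128, num_outputs=number_classes, n_hidden=embedding_size)

        x = torch.randn(size=(50, 128))
        y = torch.randint(0, number_classes, (50,))
        dataset = TensorDataset(x, y)
        loader = DataLoader(dataset)

        # pass the feature extractor, not the whole network
        model = Mahalanobis(nn.features)
        model.fit(loader, device="cpu")

        # one mean per class, each with embedding_size dimensions
        self.assertEqual(model.mu.shape[0], number_classes)
        self.assertEqual(model.mu.shape[1], embedding_size)

YacineBelHadj commented 2 months ago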

Thanks :) I just ran into this error when trying to substitute my implementation with yours. I will stop bothering you :)

I mistakenly deleted my comment.