sicara / easy-few-shot-learning

Ready-to-use code and tutorial notebooks to boost your way into few-shot learning for image classification.

cosine_distance_to_prototypes() and l2_distance_to_prototypes() are misleadingly named #143

Open · Y-T-G opened this issue 6 months ago

Y-T-G commented 6 months ago

If the model is returning cosine distance:

https://github.com/sicara/easy-few-shot-learning/blob/8422b97155f6edd506e99fd5b83362ee36865f1e/easyfsl/methods/simple_shot.py#L29

Does that mean lower is better?

ebennequin commented 6 months ago

From the code and docstring of the cosine_distance_to_prototypes() method:

    def cosine_distance_to_prototypes(self, samples) -> Tensor:
        """
        Compute prediction logits from their cosine distance to support set prototypes.
        Args:
            samples: features of the items to classify of shape (n_samples, feature_dimension)
        Returns:
            prediction logits of shape (n_samples, n_classes)
        """
        return (
            nn.functional.normalize(samples, dim=1)
            @ nn.functional.normalize(self.prototypes, dim=1).T
        )

The method doesn't actually return cosine distances but prediction logits equal to the cosine similarity to each prototype, so higher is actually better.
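
As a quick standalone sanity check (made-up tensors, not library code), the query whose direction is closest to a prototype receives the highest logit:

    import torch
    from torch import nn

    # Hypothetical toy data: 3 class prototypes and 1 query, feature_dimension = 4
    prototypes = torch.tensor(
        [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]
    )
    query = torch.tensor([[0.9, 0.1, 0.0, 0.0]])  # almost aligned with class 0

    # Same computation as cosine_distance_to_prototypes(): normalized dot products
    logits = (
        nn.functional.normalize(query, dim=1)
        @ nn.functional.normalize(prototypes, dim=1).T
    )
    print(logits)           # tensor([[0.9939, 0.1104, 0.0000]]) -- cosine similarities
    print(logits.argmax())  # tensor(0): the highest logit wins, so higher is better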

The same logic applies to the other available "distance", which also actually returns logits, equal to the negative of the distance:

    def l2_distance_to_prototypes(self, samples: Tensor) -> Tensor:
        """
        Compute prediction logits from their euclidean distance to support set prototypes.
        Args:
            samples: features of the items to classify of shape (n_samples, feature_dimension)
        Returns:
            prediction logits of shape (n_samples, n_classes)
        """
        return -torch.cdist(samples, self.prototypes)
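
Again as a standalone sketch with made-up tensors: the query nearest to a prototype in Euclidean distance gets the largest (least negative) logit:

    import torch

    # Hypothetical toy data: 3 prototypes and 1 query in a 2-dimensional feature space
    prototypes = torch.tensor([[0.0, 0.0], [5.0, 5.0], [10.0, 0.0]])
    query = torch.tensor([[4.5, 5.5]])  # closest to the second prototype

    # Same computation as l2_distance_to_prototypes(): negated Euclidean distances
    logits = -torch.cdist(query, prototypes)
    print(logits)           # tensor([[-7.1063, -0.7071, -7.7782]])
    print(logits.argmax())  # tensor(1): smallest distance -> largest logit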

Calling these methods cosine_distance_to_prototypes() and l2_distance_to_prototypes() is misleading naming. I am marking this as a much-needed enhancement to the library.

Y-T-G commented 6 months ago

I see. That makes it clear. Thanks.